goose: support provider cleanup functions

Documented and updated appropriate providers to reflect.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-03 13:13:15 -07:00
parent 1380f96c06
commit ccf63fec5d
12 changed files with 213 additions and 13 deletions

View File

@@ -145,17 +145,29 @@ func (g *gen) frame() []byte {
// inject emits the code for an injector.
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error {
results := sig.Results()
returnsErr := false
var returnsCleanup, returnsErr bool
switch results.Len() {
case 0:
return fmt.Errorf("inject %s: no return values", name)
case 1:
// nothing special
returnsCleanup, returnsErr = false, false
case 2:
if t := results.At(1).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil))
switch t := results.At(1).Type(); {
case types.Identical(t, errorType):
returnsCleanup, returnsErr = false, true
case types.Identical(t, cleanupType):
returnsCleanup, returnsErr = true, false
default:
return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil))
}
returnsErr = true
case 3:
if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil))
}
if t := results.At(2).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil))
}
returnsCleanup, returnsErr = true, true
default:
return fmt.Errorf("inject %s: too many return values", name)
}
@@ -170,6 +182,9 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
return err
}
for i := range calls {
if calls[i].hasCleanup && !returnsCleanup {
return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil))
}
if calls[i].hasErr && !returnsErr {
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
}
@@ -194,6 +209,7 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
// Set up local variables
paramNames := make([]string, params.Len())
localNames := make([]string, len(calls))
cleanupNames := make([]string, len(calls))
errVar := disambiguate("err", g.nameInFileScope)
collides := func(v string) bool {
if v == errVar {
@@ -209,6 +225,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
return true
}
}
for _, l := range cleanupNames {
if l == v {
return true
}
}
return g.nameInFileScope(v)
}
@@ -228,7 +249,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
paramNames[i] = disambiguate(a, collides)
g.p("%s %s", paramNames[i], paramTypes[i])
}
if returnsErr {
if returnsCleanup && returnsErr {
g.p(") (%s, func(), error) {\n", outTypeString)
} else if returnsCleanup {
g.p(") (%s, func()) {\n", outTypeString)
} else if returnsErr {
g.p(") (%s, error) {\n", outTypeString)
} else {
g.p(") %s {\n", outTypeString)
@@ -242,6 +267,10 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
lname = disambiguate(lname, collides)
localNames[i] = lname
g.p("\t%s", lname)
if c.hasCleanup {
cleanupNames[i] = disambiguate("cleanup", collides)
g.p(", %s", cleanupNames[i])
}
if c.hasErr {
g.p(", %s", errVar)
}
@@ -261,8 +290,17 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
g.p(")\n")
if c.hasErr {
g.p("\tif %s != nil {\n", errVar)
for j := i - 1; j >= 0; j-- {
if calls[j].hasCleanup {
g.p("\t\t%s()\n", cleanupNames[j])
}
}
g.p("\t\treturn %s", zv)
if returnsCleanup {
g.p(", nil")
}
// TODO(light): give information about failing provider
g.p("\t\treturn %s, err\n", zv)
g.p(", err\n")
g.p("\t}\n")
}
}
@@ -276,6 +314,15 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
} else {
g.p("\treturn %s", localNames[len(calls)-1])
}
if returnsCleanup {
g.p(", func() {\n")
for i := len(calls) - 1; i >= 0; i-- {
if calls[i].hasCleanup {
g.p("\t\t%s()\n", cleanupNames[i])
}
}
g.p("\t}")
}
if returnsErr {
g.p(", nil")
}
@@ -419,4 +466,7 @@ func disambiguate(name string, collides func(string) bool) string {
}
}
var errorType = types.Universe.Lookup("error").Type()
var (
errorType = types.Universe.Lookup("error").Type()
cleanupType = types.NewSignature(nil, nil, nil, false)
)