diff --git a/README.md b/README.md index 5695275..5a2df00 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,33 @@ func provideBar(foo Foo) Bar { If used as part of an injector that does not bring in the `Foo` dependency, then the injector will pass the provider the zero value as the `foo` argument. +### Cleanup functions + +If a provider creates a value that needs to be cleaned up (e.g. closing a file), +then it can return a closure to clean up the resource. The injector will use +this to either return an aggregated cleanup function to the caller or to clean +up the resource if a later provider returns an error. + +```go +//goose:provide + +func provideFile(log Logger, path Path) (*os.File, func(), error) { + f, err := os.Open(string(path)) + if err != nil { + return nil, nil, err + } + cleanup := func() { + if err := f.Close(); err != nil { + log.Log(err) + } + } + return f, cleanup, nil +} +``` + +A cleanup function is guaranteed to be called before the cleanup function of any +of the provider's inputs and must have the signature `func()`. + ## Future Work - Support for map bindings. diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index b8f9c9f..35af3bc 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -24,6 +24,8 @@ type call struct { ins []types.Type // out is the type produced by this provider call. out types.Type + // hasCleanup is true if the provider call returns a cleanup function. + hasCleanup bool // hasErr is true if the provider call returns an error. hasErr bool } @@ -113,6 +115,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr args: args, ins: ins, out: typ, + hasCleanup: p.hasCleanup, hasErr: p.hasErr, }) return nil diff --git a/internal/goose/goose.go b/internal/goose/goose.go index b7af3d4..4de0733 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -145,17 +145,29 @@ func (g *gen) frame() []byte { // inject emits the code for an injector. func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error { results := sig.Results() - returnsErr := false + var returnsCleanup, returnsErr bool switch results.Len() { case 0: return fmt.Errorf("inject %s: no return values", name) case 1: - // nothing special + returnsCleanup, returnsErr = false, false case 2: - if t := results.At(1).Type(); !types.Identical(t, errorType) { - return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil)) + switch t := results.At(1).Type(); { + case types.Identical(t, errorType): + returnsCleanup, returnsErr = false, true + case types.Identical(t, cleanupType): + returnsCleanup, returnsErr = true, false + default: + return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil)) } - returnsErr = true + case 3: + if t := results.At(1).Type(); !types.Identical(t, cleanupType) { + return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil)) + } + if t := results.At(2).Type(); !types.Identical(t, errorType) { + return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil)) + } + returnsCleanup, returnsErr = true, true default: return fmt.Errorf("inject %s: too many return values", name) } @@ -170,6 +182,9 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se return err } for i := range calls { + if calls[i].hasCleanup && !returnsCleanup { + return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil)) + } if calls[i].hasErr && !returnsErr { return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil)) } @@ -194,6 +209,7 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se // Set up local variables paramNames := make([]string, params.Len()) localNames := make([]string, len(calls)) + cleanupNames := make([]string, len(calls)) errVar := disambiguate("err", g.nameInFileScope) collides := func(v string) bool { if v == errVar { @@ -209,6 +225,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se return true } } + for _, l := range cleanupNames { + if l == v { + return true + } + } return g.nameInFileScope(v) } @@ -228,7 +249,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se paramNames[i] = disambiguate(a, collides) g.p("%s %s", paramNames[i], paramTypes[i]) } - if returnsErr { + if returnsCleanup && returnsErr { + g.p(") (%s, func(), error) {\n", outTypeString) + } else if returnsCleanup { + g.p(") (%s, func()) {\n", outTypeString) + } else if returnsErr { g.p(") (%s, error) {\n", outTypeString) } else { g.p(") %s {\n", outTypeString) @@ -242,6 +267,10 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se lname = disambiguate(lname, collides) localNames[i] = lname g.p("\t%s", lname) + if c.hasCleanup { + cleanupNames[i] = disambiguate("cleanup", collides) + g.p(", %s", cleanupNames[i]) + } if c.hasErr { g.p(", %s", errVar) } @@ -261,8 +290,17 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se g.p(")\n") if c.hasErr { g.p("\tif %s != nil {\n", errVar) + for j := i - 1; j >= 0; j-- { + if calls[j].hasCleanup { + g.p("\t\t%s()\n", cleanupNames[j]) + } + } + g.p("\t\treturn %s", zv) + if returnsCleanup { + g.p(", nil") + } // TODO(light): give information about failing provider - g.p("\t\treturn %s, err\n", zv) + g.p(", err\n") g.p("\t}\n") } } @@ -276,6 +314,15 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se } else { g.p("\treturn %s", localNames[len(calls)-1]) } + if returnsCleanup { + g.p(", func() {\n") + for i := len(calls) - 1; i >= 0; i-- { + if calls[i].hasCleanup { + g.p("\t\t%s()\n", cleanupNames[i]) + } + } + g.p("\t}") + } if returnsErr { g.p(", nil") } @@ -419,4 +466,7 @@ func disambiguate(name string, collides func(string) bool) string { } } -var errorType = types.Universe.Lookup("error").Type() +var ( + errorType = types.Universe.Lookup("error").Type() + cleanupType = types.NewSignature(nil, nil, nil, false) +) diff --git a/internal/goose/parse.go b/internal/goose/parse.go index c4f6db8..aa2c159 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -44,6 +44,7 @@ type providerInfo struct { pos token.Pos // provider function definition args []providerInput out types.Type + hasCleanup bool hasErr bool } @@ -224,15 +225,27 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope fpos := fn.Pos() r := sig.Results() - var hasErr bool + var hasCleanup, hasErr bool switch r.Len() { case 1: - hasErr = false + hasCleanup, hasErr = false, false case 2: - if t := r.At(1).Type(); !types.Identical(t, errorType) { - return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name) + switch t := r.At(1).Type(); { + case types.Identical(t, errorType): + hasCleanup, hasErr = false, true + case types.Identical(t, cleanupType): + hasCleanup, hasErr = true, false + default: + return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name.Name) } - hasErr = true + case 3: + if t := r.At(1).Type(); !types.Identical(t, cleanupType) { + return fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name.Name) + } + if t := r.At(2).Type(); !types.Identical(t, errorType) { + return fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name.Name) + } + hasCleanup, hasErr = true, true default: return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name) } @@ -244,6 +257,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope pos: fn.Pos(), args: make([]providerInput, params.Len()), out: out, + hasCleanup: hasCleanup, hasErr: hasErr, } for i := 0; i < params.Len(); i++ { diff --git a/internal/goose/testdata/Cleanup/foo/foo.go b/internal/goose/testdata/Cleanup/foo/foo.go new file mode 100644 index 0000000..e086644 --- /dev/null +++ b/internal/goose/testdata/Cleanup/foo/foo.go @@ -0,0 +1,32 @@ +package main + +import "fmt" + +func main() { + bar, cleanup := injectBar() + fmt.Println(*bar) + cleanup() + fmt.Println(*bar) +} + +type Foo int +type Bar int + +//goose:provide Foo +func provideFoo() (*Foo, func()) { + foo := new(Foo) + *foo = 42 + return foo, func() { *foo = 0 } +} + +//goose:provide Bar +func provideBar(foo *Foo) (*Bar, func()) { + bar := new(Bar) + *bar = 77 + return bar, func() { + if *foo == 0 { + panic("foo cleaned up before bar") + } + *bar = 0 + } +} diff --git a/internal/goose/testdata/Cleanup/foo/goose.go b/internal/goose/testdata/Cleanup/foo/goose.go new file mode 100644 index 0000000..b56a73e --- /dev/null +++ b/internal/goose/testdata/Cleanup/foo/goose.go @@ -0,0 +1,8 @@ +//+build gooseinject + +package main + +//goose:use Foo +//goose:use Bar + +func injectBar() (*Bar, func()) diff --git a/internal/goose/testdata/Cleanup/out.txt b/internal/goose/testdata/Cleanup/out.txt new file mode 100644 index 0000000..d770642 --- /dev/null +++ b/internal/goose/testdata/Cleanup/out.txt @@ -0,0 +1,2 @@ +77 +0 diff --git a/internal/goose/testdata/Cleanup/pkg b/internal/goose/testdata/Cleanup/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/Cleanup/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/PartialCleanup/foo/foo.go b/internal/goose/testdata/PartialCleanup/foo/foo.go new file mode 100644 index 0000000..4aec1b2 --- /dev/null +++ b/internal/goose/testdata/PartialCleanup/foo/foo.go @@ -0,0 +1,51 @@ +package main + +import ( + "errors" + "fmt" + "strings" +) + +var ( + cleanedFoo = false + cleanedBar = false +) + +func main() { + _, cleanup, err := injectBaz() + if err == nil { + fmt.Println("") + } else { + fmt.Println(strings.Contains(err.Error(), "bork!")) + } + fmt.Println(cleanedFoo, cleanedBar, cleanup == nil) +} + +type Foo int +type Bar int +type Baz int + +//goose:provide Foo +func provideFoo() (*Foo, func()) { + foo := new(Foo) + *foo = 42 + return foo, func() { *foo = 0; cleanedFoo = true } +} + +//goose:provide Bar +func provideBar(foo *Foo) (*Bar, func(), error) { + bar := new(Bar) + *bar = 77 + return bar, func() { + if *foo == 0 { + panic("foo cleaned up before bar") + } + *bar = 0 + cleanedBar = true + }, nil +} + +//goose:provide Baz +func provideBaz(bar *Bar) (Baz, error) { + return 0, errors.New("bork!") +} diff --git a/internal/goose/testdata/PartialCleanup/foo/goose.go b/internal/goose/testdata/PartialCleanup/foo/goose.go new file mode 100644 index 0000000..c9fe147 --- /dev/null +++ b/internal/goose/testdata/PartialCleanup/foo/goose.go @@ -0,0 +1,9 @@ +//+build gooseinject + +package main + +//goose:use Foo +//goose:use Bar +//goose:use Baz + +func injectBaz() (Baz, func(), error) diff --git a/internal/goose/testdata/PartialCleanup/out.txt b/internal/goose/testdata/PartialCleanup/out.txt new file mode 100644 index 0000000..649cfc2 --- /dev/null +++ b/internal/goose/testdata/PartialCleanup/out.txt @@ -0,0 +1,2 @@ +true +true true true diff --git a/internal/goose/testdata/PartialCleanup/pkg b/internal/goose/testdata/PartialCleanup/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/PartialCleanup/pkg @@ -0,0 +1 @@ +foo