goose: support provider cleanup functions
Documented and updated appropriate providers to reflect. Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
27
README.md
27
README.md
@@ -261,6 +261,33 @@ func provideBar(foo Foo) Bar {
|
|||||||
If used as part of an injector that does not bring in the `Foo` dependency, then
|
If used as part of an injector that does not bring in the `Foo` dependency, then
|
||||||
the injector will pass the provider the zero value as the `foo` argument.
|
the injector will pass the provider the zero value as the `foo` argument.
|
||||||
|
|
||||||
|
### Cleanup functions
|
||||||
|
|
||||||
|
If a provider creates a value that needs to be cleaned up (e.g. closing a file),
|
||||||
|
then it can return a closure to clean up the resource. The injector will use
|
||||||
|
this to either return an aggregated cleanup function to the caller or to clean
|
||||||
|
up the resource if a later provider returns an error.
|
||||||
|
|
||||||
|
```go
|
||||||
|
//goose:provide
|
||||||
|
|
||||||
|
func provideFile(log Logger, path Path) (*os.File, func(), error) {
|
||||||
|
f, err := os.Open(string(path))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
cleanup := func() {
|
||||||
|
if err := f.Close(); err != nil {
|
||||||
|
log.Log(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return f, cleanup, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
A cleanup function is guaranteed to be called before the cleanup function of any
|
||||||
|
of the provider's inputs and must have the signature `func()`.
|
||||||
|
|
||||||
## Future Work
|
## Future Work
|
||||||
|
|
||||||
- Support for map bindings.
|
- Support for map bindings.
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ type call struct {
|
|||||||
ins []types.Type
|
ins []types.Type
|
||||||
// out is the type produced by this provider call.
|
// out is the type produced by this provider call.
|
||||||
out types.Type
|
out types.Type
|
||||||
|
// hasCleanup is true if the provider call returns a cleanup function.
|
||||||
|
hasCleanup bool
|
||||||
// hasErr is true if the provider call returns an error.
|
// hasErr is true if the provider call returns an error.
|
||||||
hasErr bool
|
hasErr bool
|
||||||
}
|
}
|
||||||
@@ -113,6 +115,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
|
|||||||
args: args,
|
args: args,
|
||||||
ins: ins,
|
ins: ins,
|
||||||
out: typ,
|
out: typ,
|
||||||
|
hasCleanup: p.hasCleanup,
|
||||||
hasErr: p.hasErr,
|
hasErr: p.hasErr,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -145,17 +145,29 @@ func (g *gen) frame() []byte {
|
|||||||
// inject emits the code for an injector.
|
// inject emits the code for an injector.
|
||||||
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error {
|
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error {
|
||||||
results := sig.Results()
|
results := sig.Results()
|
||||||
returnsErr := false
|
var returnsCleanup, returnsErr bool
|
||||||
switch results.Len() {
|
switch results.Len() {
|
||||||
case 0:
|
case 0:
|
||||||
return fmt.Errorf("inject %s: no return values", name)
|
return fmt.Errorf("inject %s: no return values", name)
|
||||||
case 1:
|
case 1:
|
||||||
// nothing special
|
returnsCleanup, returnsErr = false, false
|
||||||
case 2:
|
case 2:
|
||||||
if t := results.At(1).Type(); !types.Identical(t, errorType) {
|
switch t := results.At(1).Type(); {
|
||||||
return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil))
|
case types.Identical(t, errorType):
|
||||||
|
returnsCleanup, returnsErr = false, true
|
||||||
|
case types.Identical(t, cleanupType):
|
||||||
|
returnsCleanup, returnsErr = true, false
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil))
|
||||||
}
|
}
|
||||||
returnsErr = true
|
case 3:
|
||||||
|
if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
|
||||||
|
return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil))
|
||||||
|
}
|
||||||
|
if t := results.At(2).Type(); !types.Identical(t, errorType) {
|
||||||
|
return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil))
|
||||||
|
}
|
||||||
|
returnsCleanup, returnsErr = true, true
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("inject %s: too many return values", name)
|
return fmt.Errorf("inject %s: too many return values", name)
|
||||||
}
|
}
|
||||||
@@ -170,6 +182,9 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range calls {
|
for i := range calls {
|
||||||
|
if calls[i].hasCleanup && !returnsCleanup {
|
||||||
|
return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil))
|
||||||
|
}
|
||||||
if calls[i].hasErr && !returnsErr {
|
if calls[i].hasErr && !returnsErr {
|
||||||
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
|
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
|
||||||
}
|
}
|
||||||
@@ -194,6 +209,7 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
// Set up local variables
|
// Set up local variables
|
||||||
paramNames := make([]string, params.Len())
|
paramNames := make([]string, params.Len())
|
||||||
localNames := make([]string, len(calls))
|
localNames := make([]string, len(calls))
|
||||||
|
cleanupNames := make([]string, len(calls))
|
||||||
errVar := disambiguate("err", g.nameInFileScope)
|
errVar := disambiguate("err", g.nameInFileScope)
|
||||||
collides := func(v string) bool {
|
collides := func(v string) bool {
|
||||||
if v == errVar {
|
if v == errVar {
|
||||||
@@ -209,6 +225,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, l := range cleanupNames {
|
||||||
|
if l == v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
return g.nameInFileScope(v)
|
return g.nameInFileScope(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -228,7 +249,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
paramNames[i] = disambiguate(a, collides)
|
paramNames[i] = disambiguate(a, collides)
|
||||||
g.p("%s %s", paramNames[i], paramTypes[i])
|
g.p("%s %s", paramNames[i], paramTypes[i])
|
||||||
}
|
}
|
||||||
if returnsErr {
|
if returnsCleanup && returnsErr {
|
||||||
|
g.p(") (%s, func(), error) {\n", outTypeString)
|
||||||
|
} else if returnsCleanup {
|
||||||
|
g.p(") (%s, func()) {\n", outTypeString)
|
||||||
|
} else if returnsErr {
|
||||||
g.p(") (%s, error) {\n", outTypeString)
|
g.p(") (%s, error) {\n", outTypeString)
|
||||||
} else {
|
} else {
|
||||||
g.p(") %s {\n", outTypeString)
|
g.p(") %s {\n", outTypeString)
|
||||||
@@ -242,6 +267,10 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
lname = disambiguate(lname, collides)
|
lname = disambiguate(lname, collides)
|
||||||
localNames[i] = lname
|
localNames[i] = lname
|
||||||
g.p("\t%s", lname)
|
g.p("\t%s", lname)
|
||||||
|
if c.hasCleanup {
|
||||||
|
cleanupNames[i] = disambiguate("cleanup", collides)
|
||||||
|
g.p(", %s", cleanupNames[i])
|
||||||
|
}
|
||||||
if c.hasErr {
|
if c.hasErr {
|
||||||
g.p(", %s", errVar)
|
g.p(", %s", errVar)
|
||||||
}
|
}
|
||||||
@@ -261,8 +290,17 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
g.p(")\n")
|
g.p(")\n")
|
||||||
if c.hasErr {
|
if c.hasErr {
|
||||||
g.p("\tif %s != nil {\n", errVar)
|
g.p("\tif %s != nil {\n", errVar)
|
||||||
|
for j := i - 1; j >= 0; j-- {
|
||||||
|
if calls[j].hasCleanup {
|
||||||
|
g.p("\t\t%s()\n", cleanupNames[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.p("\t\treturn %s", zv)
|
||||||
|
if returnsCleanup {
|
||||||
|
g.p(", nil")
|
||||||
|
}
|
||||||
// TODO(light): give information about failing provider
|
// TODO(light): give information about failing provider
|
||||||
g.p("\t\treturn %s, err\n", zv)
|
g.p(", err\n")
|
||||||
g.p("\t}\n")
|
g.p("\t}\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -276,6 +314,15 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
} else {
|
} else {
|
||||||
g.p("\treturn %s", localNames[len(calls)-1])
|
g.p("\treturn %s", localNames[len(calls)-1])
|
||||||
}
|
}
|
||||||
|
if returnsCleanup {
|
||||||
|
g.p(", func() {\n")
|
||||||
|
for i := len(calls) - 1; i >= 0; i-- {
|
||||||
|
if calls[i].hasCleanup {
|
||||||
|
g.p("\t\t%s()\n", cleanupNames[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.p("\t}")
|
||||||
|
}
|
||||||
if returnsErr {
|
if returnsErr {
|
||||||
g.p(", nil")
|
g.p(", nil")
|
||||||
}
|
}
|
||||||
@@ -419,4 +466,7 @@ func disambiguate(name string, collides func(string) bool) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errorType = types.Universe.Lookup("error").Type()
|
var (
|
||||||
|
errorType = types.Universe.Lookup("error").Type()
|
||||||
|
cleanupType = types.NewSignature(nil, nil, nil, false)
|
||||||
|
)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ type providerInfo struct {
|
|||||||
pos token.Pos // provider function definition
|
pos token.Pos // provider function definition
|
||||||
args []providerInput
|
args []providerInput
|
||||||
out types.Type
|
out types.Type
|
||||||
|
hasCleanup bool
|
||||||
hasErr bool
|
hasErr bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,15 +225,27 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
|||||||
|
|
||||||
fpos := fn.Pos()
|
fpos := fn.Pos()
|
||||||
r := sig.Results()
|
r := sig.Results()
|
||||||
var hasErr bool
|
var hasCleanup, hasErr bool
|
||||||
switch r.Len() {
|
switch r.Len() {
|
||||||
case 1:
|
case 1:
|
||||||
hasErr = false
|
hasCleanup, hasErr = false, false
|
||||||
case 2:
|
case 2:
|
||||||
if t := r.At(1).Type(); !types.Identical(t, errorType) {
|
switch t := r.At(1).Type(); {
|
||||||
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
|
case types.Identical(t, errorType):
|
||||||
|
hasCleanup, hasErr = false, true
|
||||||
|
case types.Identical(t, cleanupType):
|
||||||
|
hasCleanup, hasErr = true, false
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name.Name)
|
||||||
}
|
}
|
||||||
hasErr = true
|
case 3:
|
||||||
|
if t := r.At(1).Type(); !types.Identical(t, cleanupType) {
|
||||||
|
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
if t := r.At(2).Type(); !types.Identical(t, errorType) {
|
||||||
|
return fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
hasCleanup, hasErr = true, true
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
|
return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
|
||||||
}
|
}
|
||||||
@@ -244,6 +257,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
|||||||
pos: fn.Pos(),
|
pos: fn.Pos(),
|
||||||
args: make([]providerInput, params.Len()),
|
args: make([]providerInput, params.Len()),
|
||||||
out: out,
|
out: out,
|
||||||
|
hasCleanup: hasCleanup,
|
||||||
hasErr: hasErr,
|
hasErr: hasErr,
|
||||||
}
|
}
|
||||||
for i := 0; i < params.Len(); i++ {
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
|||||||
32
internal/goose/testdata/Cleanup/foo/foo.go
vendored
Normal file
32
internal/goose/testdata/Cleanup/foo/foo.go
vendored
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
bar, cleanup := injectBar()
|
||||||
|
fmt.Println(*bar)
|
||||||
|
cleanup()
|
||||||
|
fmt.Println(*bar)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type Bar int
|
||||||
|
|
||||||
|
//goose:provide Foo
|
||||||
|
func provideFoo() (*Foo, func()) {
|
||||||
|
foo := new(Foo)
|
||||||
|
*foo = 42
|
||||||
|
return foo, func() { *foo = 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
//goose:provide Bar
|
||||||
|
func provideBar(foo *Foo) (*Bar, func()) {
|
||||||
|
bar := new(Bar)
|
||||||
|
*bar = 77
|
||||||
|
return bar, func() {
|
||||||
|
if *foo == 0 {
|
||||||
|
panic("foo cleaned up before bar")
|
||||||
|
}
|
||||||
|
*bar = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
8
internal/goose/testdata/Cleanup/foo/goose.go
vendored
Normal file
8
internal/goose/testdata/Cleanup/foo/goose.go
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Foo
|
||||||
|
//goose:use Bar
|
||||||
|
|
||||||
|
func injectBar() (*Bar, func())
|
||||||
2
internal/goose/testdata/Cleanup/out.txt
vendored
Normal file
2
internal/goose/testdata/Cleanup/out.txt
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
77
|
||||||
|
0
|
||||||
1
internal/goose/testdata/Cleanup/pkg
vendored
Normal file
1
internal/goose/testdata/Cleanup/pkg
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
foo
|
||||||
51
internal/goose/testdata/PartialCleanup/foo/foo.go
vendored
Normal file
51
internal/goose/testdata/PartialCleanup/foo/foo.go
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cleanedFoo = false
|
||||||
|
cleanedBar = false
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
_, cleanup, err := injectBaz()
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println("<nil>")
|
||||||
|
} else {
|
||||||
|
fmt.Println(strings.Contains(err.Error(), "bork!"))
|
||||||
|
}
|
||||||
|
fmt.Println(cleanedFoo, cleanedBar, cleanup == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type Bar int
|
||||||
|
type Baz int
|
||||||
|
|
||||||
|
//goose:provide Foo
|
||||||
|
func provideFoo() (*Foo, func()) {
|
||||||
|
foo := new(Foo)
|
||||||
|
*foo = 42
|
||||||
|
return foo, func() { *foo = 0; cleanedFoo = true }
|
||||||
|
}
|
||||||
|
|
||||||
|
//goose:provide Bar
|
||||||
|
func provideBar(foo *Foo) (*Bar, func(), error) {
|
||||||
|
bar := new(Bar)
|
||||||
|
*bar = 77
|
||||||
|
return bar, func() {
|
||||||
|
if *foo == 0 {
|
||||||
|
panic("foo cleaned up before bar")
|
||||||
|
}
|
||||||
|
*bar = 0
|
||||||
|
cleanedBar = true
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//goose:provide Baz
|
||||||
|
func provideBaz(bar *Bar) (Baz, error) {
|
||||||
|
return 0, errors.New("bork!")
|
||||||
|
}
|
||||||
9
internal/goose/testdata/PartialCleanup/foo/goose.go
vendored
Normal file
9
internal/goose/testdata/PartialCleanup/foo/goose.go
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Foo
|
||||||
|
//goose:use Bar
|
||||||
|
//goose:use Baz
|
||||||
|
|
||||||
|
func injectBaz() (Baz, func(), error)
|
||||||
2
internal/goose/testdata/PartialCleanup/out.txt
vendored
Normal file
2
internal/goose/testdata/PartialCleanup/out.txt
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
true
|
||||||
|
true true true
|
||||||
1
internal/goose/testdata/PartialCleanup/pkg
vendored
Normal file
1
internal/goose/testdata/PartialCleanup/pkg
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
foo
|
||||||
Reference in New Issue
Block a user