diff --git a/README.md b/README.md index 5f22604..5695275 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,43 @@ type MySQLConnectionString string ## Advanced Features +### Binding Interfaces + +Frequently, dependency injection is used to bind concrete implementations for an +interface. goose matches inputs to outputs via [type identity][], so the +inclination might be to create a provider that returns an interface type. +However, this would not be idiomatic, since the Go best practice is to [return +concrete types][]. Instead, you can declare an interface binding in a +provider set: + +```go +type Fooer interface { + Foo() string +} + +type Bar string + +func (b *Bar) Foo() string { + return string(*b) +} + +//goose:provide BarFooer +func provideBar() *Bar { + b := new(Bar) + *b = "Hello, World!" + return b +} + +//goose:bind BarFooer Fooer *Bar +``` + +The syntax is provider set name, interface type, and finally the concrete type. +An interface binding does not necessarily need to have a provider in the same +set that provides the concrete type. + +[type identity]: https://golang.org/ref/spec#Type_identity +[return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces + ### Optional Inputs A provider input can be marked optional using `goose:optional`: @@ -230,6 +267,6 @@ the injector will pass the provider the zero value as the `foo` argument. - Support for multiple provider outputs. - Support for field binding: declare a struct as a provider and have it be filled in by the corresponding bindings from the graph. -- Currently, all dependency satisfaction is done using identity. I'd like to - use a limited form of assignability for interface types, but I'm unsure - how well this implicit satisfaction will work in practice. +- Tighter validation for a provider set (cycles in unused providers goes + unreported currently) +- Visualization for provider sets diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index 7e3a5a4..b8f9c9f 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -30,7 +30,7 @@ type call struct { // solve finds the sequence of calls required to produce an output type // with an optional set of provided inputs. -func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) { +func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) { for i, g := range given { for _, h := range given[:i] { if types.Identical(g, h) { @@ -82,6 +82,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov // TODO(light): give name of provider return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil)) } + if !types.Identical(p.out, typ) { + // Interface binding. Don't create a call ourselves. + if err := visit(append(trail, providerInput{typ: p.out})); err != nil { + return err + } + index.Set(typ, index.At(p.out)) + return nil + } for _, a := range p.args { // TODO(light): this will discard grown trail arrays. if err := visit(append(trail, a)); err != nil { @@ -115,16 +123,22 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov return calls, nil } -func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Map, error) { +func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) { type nextEnt struct { - to providerSetRef + to symref - from providerSetRef + from symref pos token.Pos } + type binding struct { + ifaceBinding + pset symref + from symref + } pm := new(typeutil.Map) // to *providerInfo - visited := make(map[providerSetRef]struct{}) + var bindings []binding + visited := make(map[symref]struct{}) var next []nextEnt for _, ref := range sets { next = append(next, nextEnt{to: ref}) @@ -137,28 +151,60 @@ func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Ma continue } visited[curr.to] = struct{}{} - mod, err := mc.get(curr.to) + pset, err := mc.get(curr.to) if err != nil { if !curr.pos.IsValid() { return nil, err } return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) } - for _, p := range mod.providers { + for _, p := range pset.providers { if prev := pm.At(p.out); prev != nil { pos := mc.fset.Position(p.pos) typ := types.TypeString(p.out, nil) prevPos := mc.fset.Position(prev.(*providerInfo).pos) - if curr.from.importPath != "" { + if curr.from.importPath == "" { + // Provider set is imported directly by injector. return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) } return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos) } pm.Set(p.out, p) } - for _, imp := range mod.imports { - next = append(next, nextEnt{to: imp.providerSetRef, from: curr.to, pos: imp.pos}) + for _, b := range pset.bindings { + bindings = append(bindings, binding{ + ifaceBinding: b, + pset: curr.to, + from: curr.from, + }) } + for _, imp := range pset.imports { + next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos}) + } + } + for _, b := range bindings { + if prev := pm.At(b.iface); prev != nil { + pos := mc.fset.Position(b.pos) + typ := types.TypeString(b.iface, nil) + // TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding. + prevPos := mc.fset.Position(prev.(*providerInfo).pos) + if b.from.importPath == "" { + // Provider set is imported directly by injector. + return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) + } + return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos) + } + concrete := pm.At(b.provided) + if concrete == nil { + pos := mc.fset.Position(b.pos) + typ := types.TypeString(b.provided, nil) + if b.from.importPath == "" { + // Concrete provider is imported directly by injector. + return nil, fmt.Errorf("%v: no binding for %s", pos, typ) + } + return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from) + } + pm.Set(b.iface, concrete) } return pm, nil } diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 7ae2daf..b7af3d4 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -66,7 +66,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { if dg.decl != decl { dg = directiveGroup{} } - var sets []providerSetRef + var sets []symref for _, d := range dg.dirs { if d.kind != "use" { return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind) @@ -76,7 +76,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos)) } for _, arg := range args { - ref, err := parseProviderSetRef(r, arg, fileScope, g.currPackage, d.pos) + ref, err := parseSymbolRef(r, arg, fileScope, g.currPackage, d.pos) if err != nil { return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) } @@ -143,7 +143,7 @@ func (g *gen) frame() []byte { } // inject emits the code for an injector. -func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []providerSetRef) error { +func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error { results := sig.Results() returnsErr := false switch results.Len() { diff --git a/internal/goose/parse.go b/internal/goose/parse.go index 6e38a5b..c4f6db8 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -18,11 +18,22 @@ import ( // providerSet. type providerSet struct { providers []*providerInfo + bindings []ifaceBinding imports []providerSetImport } +// An ifaceBinding declares that a type should be used to satisfy inputs +// of the given interface type. +// +// provided is always a type that is assignable to iface. +type ifaceBinding struct { + iface types.Type + provided types.Type + pos token.Pos +} + type providerSetImport struct { - providerSetRef + symref pos token.Pos } @@ -30,7 +41,7 @@ type providerSetImport struct { type providerInfo struct { importPath string funcName string - pos token.Pos + pos token.Pos // provider function definition args []providerInput out types.Type hasErr bool @@ -80,43 +91,94 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) case "use": // Ignore, picked up by injector flow. + case "bind": + args := d.args() + if len(args) != 3 { + return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", fctx.fset.Position(d.pos)) + } + ifaceRef, err := parseSymbolRef(fctx.r, args[1], scope, fctx.pkg.Path(), d.pos) + if err != nil { + return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + } + ifaceObj, err := ifaceRef.resolveObject(fctx.pkg) + if err != nil { + return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + } + ifaceDecl, ok := ifaceObj.(*types.TypeName) + if !ok { + return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), ifaceRef) + } + iface := ifaceDecl.Type() + methodSet, ok := iface.Underlying().(*types.Interface) + if !ok { + return fmt.Errorf("%v: %v does not name an interface type", fctx.fset.Position(d.pos), ifaceRef) + } + + providedRef, err := parseSymbolRef(fctx.r, strings.TrimPrefix(args[2], "*"), scope, fctx.pkg.Path(), d.pos) + if err != nil { + return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + } + providedObj, err := providedRef.resolveObject(fctx.pkg) + if err != nil { + return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + } + providedDecl, ok := providedObj.(*types.TypeName) + if !ok { + return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), providedRef) + } + provided := providedDecl.Type() + if types.Identical(provided, iface) { + return fmt.Errorf("%v: cannot bind interface to itself", fctx.fset.Position(d.pos)) + } + if strings.HasPrefix(args[2], "*") { + provided = types.NewPointer(provided) + } + if !types.Implements(provided, methodSet) { + return fmt.Errorf("%v: %s does not implement %s", fctx.fset.Position(d.pos), types.TypeString(provided, nil), types.TypeString(iface, nil)) + } + + name := args[0] + if pset := sets[name]; pset != nil { + pset.bindings = append(pset.bindings, ifaceBinding{ + iface: iface, + provided: provided, + }) + } else { + sets[name] = &providerSet{ + bindings: []ifaceBinding{{ + iface: iface, + provided: provided, + }}, + } + } case "import": args := d.args() if len(args) < 2 { - return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) + return fmt.Errorf("%v: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) } name := args[0] for _, spec := range args[1:] { - ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) + ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) if err != nil { return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) } - if ref.importPath != fctx.pkg.Path() { - imported := false - for _, imp := range fctx.pkg.Imports() { - if ref.importPath == imp.Path() { - imported = true - break - } - } - if !imported { - return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) - } + if findImport(fctx.pkg, ref.importPath) == nil { + return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) } if mod := sets[name]; mod != nil { found := false for _, other := range mod.imports { - if ref == other.providerSetRef { + if ref == other.symref { found = true break } } if !found { - mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos}) + mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos}) } } else { sets[name] = &providerSet{ - imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}}, + imports: []providerSetImport{{symref: ref, pos: d.pos}}, } } } @@ -233,7 +295,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa } } -func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) { +func (mc *providerSetCache) get(ref symref) (*providerSet, error) { if mods, cached := mc.sets[ref.importPath]; cached { mod := mods[ref.name] if mod == nil { @@ -263,46 +325,58 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) { return mod, nil } -// A providerSetRef is a parsed reference to a collection of providers. -type providerSetRef struct { +// A symref is a parsed reference to a symbol (either a provider set or a Go object). +type symref struct { importPath string name string } -func parseProviderSetRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) { +func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) { // TODO(light): verify that provider set name is an identifier before returning i := strings.LastIndexByte(ref, '.') if i == -1 { - return providerSetRef{importPath: pkg, name: ref}, nil + return symref{importPath: pkg, name: ref}, nil } imp, name := ref[:i], ref[i+1:] if strings.HasPrefix(imp, `"`) { path, err := strconv.Unquote(imp) if err != nil { - return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref) + return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref) } path, err = r.resolve(pos, path) if err != nil { - return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %v", ref, err) + return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err) } - return providerSetRef{importPath: path, name: name}, nil + return symref{importPath: path, name: name}, nil } _, obj := s.LookupParent(imp, pos) if obj == nil { - return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp) + return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, imp) } pn, ok := obj.(*types.PkgName) if !ok { - return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp) + return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, imp) } - return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil + return symref{importPath: pn.Imported().Path(), name: name}, nil } -func (ref providerSetRef) String() string { +func (ref symref) String() string { return strconv.Quote(ref.importPath) + "." + ref.name } +func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) { + imp := findImport(pkg, ref.importPath) + if imp == nil { + return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref) + } + obj := imp.Scope().Lookup(ref.name) + if obj == nil { + return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name) + } + return obj, nil +} + type importResolver struct { fset *token.FileSet bctx *build.Context @@ -333,6 +407,18 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) { return pkg.ImportPath, nil } +func findImport(pkg *types.Package, path string) *types.Package { + if pkg.Path() == path { + return pkg + } + for _, imp := range pkg.Imports() { + if imp.Path() == path { + return imp + } + } + return nil +} + // A directive is a parsed goose comment. type directive struct { pos token.Pos diff --git a/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go b/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go new file mode 100644 index 0000000..9970572 --- /dev/null +++ b/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + + _ "foo" +) + +func main() { + fmt.Println(injectFooer().Foo()) +} + +type Bar string + +func (b *Bar) Foo() string { + return string(*b) +} + +//goose:provide +func provideBar() *Bar { + b := new(Bar) + *b = "Hello, World!" + return b +} + +//goose:bind provideBar "foo".Fooer *Bar diff --git a/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go b/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go new file mode 100644 index 0000000..46812cc --- /dev/null +++ b/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go @@ -0,0 +1,9 @@ +//+build gooseinject + +package main + +import "foo" + +//goose:use provideBar + +func injectFooer() foo.Fooer diff --git a/internal/goose/testdata/ImportedInterfaceBinding/foo/foo.go b/internal/goose/testdata/ImportedInterfaceBinding/foo/foo.go new file mode 100644 index 0000000..4262566 --- /dev/null +++ b/internal/goose/testdata/ImportedInterfaceBinding/foo/foo.go @@ -0,0 +1,5 @@ +package foo + +type Fooer interface { + Foo() string +} diff --git a/internal/goose/testdata/ImportedInterfaceBinding/out.txt b/internal/goose/testdata/ImportedInterfaceBinding/out.txt new file mode 100644 index 0000000..8ab686e --- /dev/null +++ b/internal/goose/testdata/ImportedInterfaceBinding/out.txt @@ -0,0 +1 @@ +Hello, World! diff --git a/internal/goose/testdata/ImportedInterfaceBinding/pkg b/internal/goose/testdata/ImportedInterfaceBinding/pkg new file mode 100644 index 0000000..5716ca5 --- /dev/null +++ b/internal/goose/testdata/ImportedInterfaceBinding/pkg @@ -0,0 +1 @@ +bar diff --git a/internal/goose/testdata/InterfaceBinding/foo/foo.go b/internal/goose/testdata/InterfaceBinding/foo/foo.go new file mode 100644 index 0000000..50c523c --- /dev/null +++ b/internal/goose/testdata/InterfaceBinding/foo/foo.go @@ -0,0 +1,26 @@ +package main + +import "fmt" + +func main() { + fmt.Println(injectFooer().Foo()) +} + +type Fooer interface { + Foo() string +} + +type Bar string + +func (b *Bar) Foo() string { + return string(*b) +} + +//goose:provide +func provideBar() *Bar { + b := new(Bar) + *b = "Hello, World!" + return b +} + +//goose:bind provideBar Fooer *Bar diff --git a/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go b/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go new file mode 100644 index 0000000..38876cb --- /dev/null +++ b/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use provideBar + +func injectFooer() Fooer diff --git a/internal/goose/testdata/InterfaceBinding/out.txt b/internal/goose/testdata/InterfaceBinding/out.txt new file mode 100644 index 0000000..8ab686e --- /dev/null +++ b/internal/goose/testdata/InterfaceBinding/out.txt @@ -0,0 +1 @@ +Hello, World! diff --git a/internal/goose/testdata/InterfaceBinding/pkg b/internal/goose/testdata/InterfaceBinding/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/InterfaceBinding/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go b/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go new file mode 100644 index 0000000..bb097eb --- /dev/null +++ b/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go @@ -0,0 +1,50 @@ +// This test verifies that the concrete type is provided only once, even if an +// interface additionally depends on it. + +package main + +import ( + "fmt" + "sync" +) + +func main() { + injectFooBar() + fmt.Println(provideBarCalls) +} + +type Fooer interface { + Foo() string +} + +type Bar string + +type FooBar struct { + Fooer Fooer + Bar *Bar +} + +func (b *Bar) Foo() string { + return string(*b) +} + +//goose:provide +//goose:bind provideBar Fooer *Bar +func provideBar() *Bar { + mu.Lock() + provideBarCalls++ + mu.Unlock() + b := new(Bar) + *b = "Hello, World!" + return b +} + +var ( + mu sync.Mutex + provideBarCalls int +) + +//goose:provide +func provideFooBar(fooer Fooer, bar *Bar) FooBar { + return FooBar{fooer, bar} +} diff --git a/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go b/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go new file mode 100644 index 0000000..48a1c01 --- /dev/null +++ b/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go @@ -0,0 +1,8 @@ +//+build gooseinject + +package main + +//goose:use provideBar +//goose:use provideFooBar + +func injectFooBar() FooBar diff --git a/internal/goose/testdata/InterfaceBindingReuse/out.txt b/internal/goose/testdata/InterfaceBindingReuse/out.txt new file mode 100644 index 0000000..d00491f --- /dev/null +++ b/internal/goose/testdata/InterfaceBindingReuse/out.txt @@ -0,0 +1 @@ +1 diff --git a/internal/goose/testdata/InterfaceBindingReuse/pkg b/internal/goose/testdata/InterfaceBindingReuse/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/InterfaceBindingReuse/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/NoImplicitInterface/foo/foo.go b/internal/goose/testdata/NoImplicitInterface/foo/foo.go new file mode 100644 index 0000000..01585ff --- /dev/null +++ b/internal/goose/testdata/NoImplicitInterface/foo/foo.go @@ -0,0 +1,22 @@ +package main + +import "fmt" + +func main() { + fmt.Println(injectFooer().Foo()) +} + +type Fooer interface { + Foo() string +} + +type Bar string + +func (b Bar) Foo() string { + return string(b) +} + +//goose:provide +func provideBar() Bar { + return "Hello, World!" +} diff --git a/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go b/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go new file mode 100644 index 0000000..38876cb --- /dev/null +++ b/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use provideBar + +func injectFooer() Fooer diff --git a/internal/goose/testdata/NoImplicitInterface/out.txt b/internal/goose/testdata/NoImplicitInterface/out.txt new file mode 100644 index 0000000..5df7507 --- /dev/null +++ b/internal/goose/testdata/NoImplicitInterface/out.txt @@ -0,0 +1 @@ +ERROR diff --git a/internal/goose/testdata/NoImplicitInterface/pkg b/internal/goose/testdata/NoImplicitInterface/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/NoImplicitInterface/pkg @@ -0,0 +1 @@ +foo