goose: add interface binding

An interface binding instructs goose that a concrete type should be used
to satisfy a dependency on an interface type. goose could determine this
implicitly, but having an explicit directive makes the provider author's
intent clear and allows different concrete types to satisfy different
smaller interfaces.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-02 14:08:17 -07:00
parent 73d4c0f0fc
commit 1380f96c06
21 changed files with 383 additions and 46 deletions

View File

@@ -208,6 +208,43 @@ type MySQLConnectionString string
## Advanced Features ## Advanced Features
### Binding Interfaces
Frequently, dependency injection is used to bind concrete implementations for an
interface. goose matches inputs to outputs via [type identity][], so the
inclination might be to create a provider that returns an interface type.
However, this would not be idiomatic, since the Go best practice is to [return
concrete types][]. Instead, you can declare an interface binding in a
provider set:
```go
type Fooer interface {
Foo() string
}
type Bar string
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide BarFooer
func provideBar() *Bar {
b := new(Bar)
*b = "Hello, World!"
return b
}
//goose:bind BarFooer Fooer *Bar
```
The syntax is provider set name, interface type, and finally the concrete type.
An interface binding does not necessarily need to have a provider in the same
set that provides the concrete type.
[type identity]: https://golang.org/ref/spec#Type_identity
[return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces
### Optional Inputs ### Optional Inputs
A provider input can be marked optional using `goose:optional`: A provider input can be marked optional using `goose:optional`:
@@ -230,6 +267,6 @@ the injector will pass the provider the zero value as the `foo` argument.
- Support for multiple provider outputs. - Support for multiple provider outputs.
- Support for field binding: declare a struct as a provider and have it be - Support for field binding: declare a struct as a provider and have it be
filled in by the corresponding bindings from the graph. filled in by the corresponding bindings from the graph.
- Currently, all dependency satisfaction is done using identity. I'd like to - Tighter validation for a provider set (cycles in unused providers goes
use a limited form of assignability for interface types, but I'm unsure unreported currently)
how well this implicit satisfaction will work in practice. - Visualization for provider sets

View File

@@ -30,7 +30,7 @@ type call struct {
// solve finds the sequence of calls required to produce an output type // solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs. // with an optional set of provided inputs.
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) { func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) {
for i, g := range given { for i, g := range given {
for _, h := range given[:i] { for _, h := range given[:i] {
if types.Identical(g, h) { if types.Identical(g, h) {
@@ -82,6 +82,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
// TODO(light): give name of provider // TODO(light): give name of provider
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil)) return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
} }
if !types.Identical(p.out, typ) {
// Interface binding. Don't create a call ourselves.
if err := visit(append(trail, providerInput{typ: p.out})); err != nil {
return err
}
index.Set(typ, index.At(p.out))
return nil
}
for _, a := range p.args { for _, a := range p.args {
// TODO(light): this will discard grown trail arrays. // TODO(light): this will discard grown trail arrays.
if err := visit(append(trail, a)); err != nil { if err := visit(append(trail, a)); err != nil {
@@ -115,16 +123,22 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
return calls, nil return calls, nil
} }
func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Map, error) { func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) {
type nextEnt struct { type nextEnt struct {
to providerSetRef to symref
from providerSetRef from symref
pos token.Pos pos token.Pos
} }
type binding struct {
ifaceBinding
pset symref
from symref
}
pm := new(typeutil.Map) // to *providerInfo pm := new(typeutil.Map) // to *providerInfo
visited := make(map[providerSetRef]struct{}) var bindings []binding
visited := make(map[symref]struct{})
var next []nextEnt var next []nextEnt
for _, ref := range sets { for _, ref := range sets {
next = append(next, nextEnt{to: ref}) next = append(next, nextEnt{to: ref})
@@ -137,28 +151,60 @@ func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Ma
continue continue
} }
visited[curr.to] = struct{}{} visited[curr.to] = struct{}{}
mod, err := mc.get(curr.to) pset, err := mc.get(curr.to)
if err != nil { if err != nil {
if !curr.pos.IsValid() { if !curr.pos.IsValid() {
return nil, err return nil, err
} }
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
} }
for _, p := range mod.providers { for _, p := range pset.providers {
if prev := pm.At(p.out); prev != nil { if prev := pm.At(p.out); prev != nil {
pos := mc.fset.Position(p.pos) pos := mc.fset.Position(p.pos)
typ := types.TypeString(p.out, nil) typ := types.TypeString(p.out, nil)
prevPos := mc.fset.Position(prev.(*providerInfo).pos) prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if curr.from.importPath != "" { if curr.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
} }
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos) return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
} }
pm.Set(p.out, p) pm.Set(p.out, p)
} }
for _, imp := range mod.imports { for _, b := range pset.bindings {
next = append(next, nextEnt{to: imp.providerSetRef, from: curr.to, pos: imp.pos}) bindings = append(bindings, binding{
ifaceBinding: b,
pset: curr.to,
from: curr.from,
})
} }
for _, imp := range pset.imports {
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
}
}
for _, b := range bindings {
if prev := pm.At(b.iface); prev != nil {
pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.iface, nil)
// TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding.
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if b.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
}
concrete := pm.At(b.provided)
if concrete == nil {
pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.provided, nil)
if b.from.importPath == "" {
// Concrete provider is imported directly by injector.
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
}
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
}
pm.Set(b.iface, concrete)
} }
return pm, nil return pm, nil
} }

View File

@@ -66,7 +66,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
if dg.decl != decl { if dg.decl != decl {
dg = directiveGroup{} dg = directiveGroup{}
} }
var sets []providerSetRef var sets []symref
for _, d := range dg.dirs { for _, d := range dg.dirs {
if d.kind != "use" { if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind) return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
@@ -76,7 +76,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos)) return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos))
} }
for _, arg := range args { for _, arg := range args {
ref, err := parseProviderSetRef(r, arg, fileScope, g.currPackage, d.pos) ref, err := parseSymbolRef(r, arg, fileScope, g.currPackage, d.pos)
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
} }
@@ -143,7 +143,7 @@ func (g *gen) frame() []byte {
} }
// inject emits the code for an injector. // inject emits the code for an injector.
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []providerSetRef) error { func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error {
results := sig.Results() results := sig.Results()
returnsErr := false returnsErr := false
switch results.Len() { switch results.Len() {

View File

@@ -18,11 +18,22 @@ import (
// providerSet. // providerSet.
type providerSet struct { type providerSet struct {
providers []*providerInfo providers []*providerInfo
bindings []ifaceBinding
imports []providerSetImport imports []providerSetImport
} }
// An ifaceBinding declares that a type should be used to satisfy inputs
// of the given interface type.
//
// provided is always a type that is assignable to iface.
type ifaceBinding struct {
iface types.Type
provided types.Type
pos token.Pos
}
type providerSetImport struct { type providerSetImport struct {
providerSetRef symref
pos token.Pos pos token.Pos
} }
@@ -30,7 +41,7 @@ type providerSetImport struct {
type providerInfo struct { type providerInfo struct {
importPath string importPath string
funcName string funcName string
pos token.Pos pos token.Pos // provider function definition
args []providerInput args []providerInput
out types.Type out types.Type
hasErr bool hasErr bool
@@ -80,43 +91,94 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
case "use": case "use":
// Ignore, picked up by injector flow. // Ignore, picked up by injector flow.
case "bind":
args := d.args()
if len(args) != 3 {
return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", fctx.fset.Position(d.pos))
}
ifaceRef, err := parseSymbolRef(fctx.r, args[1], scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
ifaceObj, err := ifaceRef.resolveObject(fctx.pkg)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
ifaceDecl, ok := ifaceObj.(*types.TypeName)
if !ok {
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), ifaceRef)
}
iface := ifaceDecl.Type()
methodSet, ok := iface.Underlying().(*types.Interface)
if !ok {
return fmt.Errorf("%v: %v does not name an interface type", fctx.fset.Position(d.pos), ifaceRef)
}
providedRef, err := parseSymbolRef(fctx.r, strings.TrimPrefix(args[2], "*"), scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
providedObj, err := providedRef.resolveObject(fctx.pkg)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
providedDecl, ok := providedObj.(*types.TypeName)
if !ok {
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), providedRef)
}
provided := providedDecl.Type()
if types.Identical(provided, iface) {
return fmt.Errorf("%v: cannot bind interface to itself", fctx.fset.Position(d.pos))
}
if strings.HasPrefix(args[2], "*") {
provided = types.NewPointer(provided)
}
if !types.Implements(provided, methodSet) {
return fmt.Errorf("%v: %s does not implement %s", fctx.fset.Position(d.pos), types.TypeString(provided, nil), types.TypeString(iface, nil))
}
name := args[0]
if pset := sets[name]; pset != nil {
pset.bindings = append(pset.bindings, ifaceBinding{
iface: iface,
provided: provided,
})
} else {
sets[name] = &providerSet{
bindings: []ifaceBinding{{
iface: iface,
provided: provided,
}},
}
}
case "import": case "import":
args := d.args() args := d.args()
if len(args) < 2 { if len(args) < 2 {
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) return fmt.Errorf("%v: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
} }
name := args[0] name := args[0]
for _, spec := range args[1:] { for _, spec := range args[1:] {
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
if err != nil { if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
} }
if ref.importPath != fctx.pkg.Path() { if findImport(fctx.pkg, ref.importPath) == nil {
imported := false return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
for _, imp := range fctx.pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
} }
if mod := sets[name]; mod != nil { if mod := sets[name]; mod != nil {
found := false found := false
for _, other := range mod.imports { for _, other := range mod.imports {
if ref == other.providerSetRef { if ref == other.symref {
found = true found = true
break break
} }
} }
if !found { if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos}) mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos})
} }
} else { } else {
sets[name] = &providerSet{ sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}}, imports: []providerSetImport{{symref: ref, pos: d.pos}},
} }
} }
} }
@@ -233,7 +295,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa
} }
} }
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) { func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
if mods, cached := mc.sets[ref.importPath]; cached { if mods, cached := mc.sets[ref.importPath]; cached {
mod := mods[ref.name] mod := mods[ref.name]
if mod == nil { if mod == nil {
@@ -263,46 +325,58 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
return mod, nil return mod, nil
} }
// A providerSetRef is a parsed reference to a collection of providers. // A symref is a parsed reference to a symbol (either a provider set or a Go object).
type providerSetRef struct { type symref struct {
importPath string importPath string
name string name string
} }
func parseProviderSetRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) { func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) {
// TODO(light): verify that provider set name is an identifier before returning // TODO(light): verify that provider set name is an identifier before returning
i := strings.LastIndexByte(ref, '.') i := strings.LastIndexByte(ref, '.')
if i == -1 { if i == -1 {
return providerSetRef{importPath: pkg, name: ref}, nil return symref{importPath: pkg, name: ref}, nil
} }
imp, name := ref[:i], ref[i+1:] imp, name := ref[:i], ref[i+1:]
if strings.HasPrefix(imp, `"`) { if strings.HasPrefix(imp, `"`) {
path, err := strconv.Unquote(imp) path, err := strconv.Unquote(imp)
if err != nil { if err != nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref) return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref)
} }
path, err = r.resolve(pos, path) path, err = r.resolve(pos, path)
if err != nil { if err != nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %v", ref, err) return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err)
} }
return providerSetRef{importPath: path, name: name}, nil return symref{importPath: path, name: name}, nil
} }
_, obj := s.LookupParent(imp, pos) _, obj := s.LookupParent(imp, pos)
if obj == nil { if obj == nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp) return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, imp)
} }
pn, ok := obj.(*types.PkgName) pn, ok := obj.(*types.PkgName)
if !ok { if !ok {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp) return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, imp)
} }
return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil return symref{importPath: pn.Imported().Path(), name: name}, nil
} }
func (ref providerSetRef) String() string { func (ref symref) String() string {
return strconv.Quote(ref.importPath) + "." + ref.name return strconv.Quote(ref.importPath) + "." + ref.name
} }
func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) {
imp := findImport(pkg, ref.importPath)
if imp == nil {
return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref)
}
obj := imp.Scope().Lookup(ref.name)
if obj == nil {
return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name)
}
return obj, nil
}
type importResolver struct { type importResolver struct {
fset *token.FileSet fset *token.FileSet
bctx *build.Context bctx *build.Context
@@ -333,6 +407,18 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
return pkg.ImportPath, nil return pkg.ImportPath, nil
} }
func findImport(pkg *types.Package, path string) *types.Package {
if pkg.Path() == path {
return pkg
}
for _, imp := range pkg.Imports() {
if imp.Path() == path {
return imp
}
}
return nil
}
// A directive is a parsed goose comment. // A directive is a parsed goose comment.
type directive struct { type directive struct {
pos token.Pos pos token.Pos

View File

@@ -0,0 +1,26 @@
package main
import (
"fmt"
_ "foo"
)
func main() {
fmt.Println(injectFooer().Foo())
}
type Bar string
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide
func provideBar() *Bar {
b := new(Bar)
*b = "Hello, World!"
return b
}
//goose:bind provideBar "foo".Fooer *Bar

View File

@@ -0,0 +1,9 @@
//+build gooseinject
package main
import "foo"
//goose:use provideBar
func injectFooer() foo.Fooer

View File

@@ -0,0 +1,5 @@
package foo
type Fooer interface {
Foo() string
}

View File

@@ -0,0 +1 @@
Hello, World!

View File

@@ -0,0 +1 @@
bar

View File

@@ -0,0 +1,26 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooer().Foo())
}
type Fooer interface {
Foo() string
}
type Bar string
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide
func provideBar() *Bar {
b := new(Bar)
*b = "Hello, World!"
return b
}
//goose:bind provideBar Fooer *Bar

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use provideBar
func injectFooer() Fooer

View File

@@ -0,0 +1 @@
Hello, World!

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,50 @@
// This test verifies that the concrete type is provided only once, even if an
// interface additionally depends on it.
package main
import (
"fmt"
"sync"
)
func main() {
injectFooBar()
fmt.Println(provideBarCalls)
}
type Fooer interface {
Foo() string
}
type Bar string
type FooBar struct {
Fooer Fooer
Bar *Bar
}
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide
//goose:bind provideBar Fooer *Bar
func provideBar() *Bar {
mu.Lock()
provideBarCalls++
mu.Unlock()
b := new(Bar)
*b = "Hello, World!"
return b
}
var (
mu sync.Mutex
provideBarCalls int
)
//goose:provide
func provideFooBar(fooer Fooer, bar *Bar) FooBar {
return FooBar{fooer, bar}
}

View File

@@ -0,0 +1,8 @@
//+build gooseinject
package main
//goose:use provideBar
//goose:use provideFooBar
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,22 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooer().Foo())
}
type Fooer interface {
Foo() string
}
type Bar string
func (b Bar) Foo() string {
return string(b)
}
//goose:provide
func provideBar() Bar {
return "Hello, World!"
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use provideBar
func injectFooer() Fooer

View File

@@ -0,0 +1 @@
ERROR

View File

@@ -0,0 +1 @@
foo