goose: add interface binding

An interface binding instructs goose that a concrete type should be used
to satisfy a dependency on an interface type. goose could determine this
implicitly, but having an explicit directive makes the provider author's
intent clear and allows different concrete types to satisfy different
smaller interfaces.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-02 14:08:17 -07:00
parent 73d4c0f0fc
commit 1380f96c06
21 changed files with 383 additions and 46 deletions

View File

@@ -30,7 +30,7 @@ type call struct {
// solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs.
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) {
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) {
for i, g := range given {
for _, h := range given[:i] {
if types.Identical(g, h) {
@@ -82,6 +82,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
// TODO(light): give name of provider
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
}
if !types.Identical(p.out, typ) {
// Interface binding. Don't create a call ourselves.
if err := visit(append(trail, providerInput{typ: p.out})); err != nil {
return err
}
index.Set(typ, index.At(p.out))
return nil
}
for _, a := range p.args {
// TODO(light): this will discard grown trail arrays.
if err := visit(append(trail, a)); err != nil {
@@ -115,16 +123,22 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
return calls, nil
}
func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Map, error) {
func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) {
type nextEnt struct {
to providerSetRef
to symref
from providerSetRef
from symref
pos token.Pos
}
type binding struct {
ifaceBinding
pset symref
from symref
}
pm := new(typeutil.Map) // to *providerInfo
visited := make(map[providerSetRef]struct{})
var bindings []binding
visited := make(map[symref]struct{})
var next []nextEnt
for _, ref := range sets {
next = append(next, nextEnt{to: ref})
@@ -137,28 +151,60 @@ func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Ma
continue
}
visited[curr.to] = struct{}{}
mod, err := mc.get(curr.to)
pset, err := mc.get(curr.to)
if err != nil {
if !curr.pos.IsValid() {
return nil, err
}
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
}
for _, p := range mod.providers {
for _, p := range pset.providers {
if prev := pm.At(p.out); prev != nil {
pos := mc.fset.Position(p.pos)
typ := types.TypeString(p.out, nil)
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if curr.from.importPath != "" {
if curr.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
}
pm.Set(p.out, p)
}
for _, imp := range mod.imports {
next = append(next, nextEnt{to: imp.providerSetRef, from: curr.to, pos: imp.pos})
for _, b := range pset.bindings {
bindings = append(bindings, binding{
ifaceBinding: b,
pset: curr.to,
from: curr.from,
})
}
for _, imp := range pset.imports {
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
}
}
for _, b := range bindings {
if prev := pm.At(b.iface); prev != nil {
pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.iface, nil)
// TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding.
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if b.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
}
concrete := pm.At(b.provided)
if concrete == nil {
pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.provided, nil)
if b.from.importPath == "" {
// Concrete provider is imported directly by injector.
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
}
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
}
pm.Set(b.iface, concrete)
}
return pm, nil
}

View File

@@ -66,7 +66,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
if dg.decl != decl {
dg = directiveGroup{}
}
var sets []providerSetRef
var sets []symref
for _, d := range dg.dirs {
if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
@@ -76,7 +76,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos))
}
for _, arg := range args {
ref, err := parseProviderSetRef(r, arg, fileScope, g.currPackage, d.pos)
ref, err := parseSymbolRef(r, arg, fileScope, g.currPackage, d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
}
@@ -143,7 +143,7 @@ func (g *gen) frame() []byte {
}
// inject emits the code for an injector.
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []providerSetRef) error {
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error {
results := sig.Results()
returnsErr := false
switch results.Len() {

View File

@@ -18,11 +18,22 @@ import (
// providerSet.
type providerSet struct {
providers []*providerInfo
bindings []ifaceBinding
imports []providerSetImport
}
// An ifaceBinding declares that a type should be used to satisfy inputs
// of the given interface type.
//
// provided is always a type that is assignable to iface.
type ifaceBinding struct {
iface types.Type
provided types.Type
pos token.Pos
}
type providerSetImport struct {
providerSetRef
symref
pos token.Pos
}
@@ -30,7 +41,7 @@ type providerSetImport struct {
type providerInfo struct {
importPath string
funcName string
pos token.Pos
pos token.Pos // provider function definition
args []providerInput
out types.Type
hasErr bool
@@ -80,43 +91,94 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
case "use":
// Ignore, picked up by injector flow.
case "bind":
args := d.args()
if len(args) != 3 {
return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", fctx.fset.Position(d.pos))
}
ifaceRef, err := parseSymbolRef(fctx.r, args[1], scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
ifaceObj, err := ifaceRef.resolveObject(fctx.pkg)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
ifaceDecl, ok := ifaceObj.(*types.TypeName)
if !ok {
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), ifaceRef)
}
iface := ifaceDecl.Type()
methodSet, ok := iface.Underlying().(*types.Interface)
if !ok {
return fmt.Errorf("%v: %v does not name an interface type", fctx.fset.Position(d.pos), ifaceRef)
}
providedRef, err := parseSymbolRef(fctx.r, strings.TrimPrefix(args[2], "*"), scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
providedObj, err := providedRef.resolveObject(fctx.pkg)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
providedDecl, ok := providedObj.(*types.TypeName)
if !ok {
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), providedRef)
}
provided := providedDecl.Type()
if types.Identical(provided, iface) {
return fmt.Errorf("%v: cannot bind interface to itself", fctx.fset.Position(d.pos))
}
if strings.HasPrefix(args[2], "*") {
provided = types.NewPointer(provided)
}
if !types.Implements(provided, methodSet) {
return fmt.Errorf("%v: %s does not implement %s", fctx.fset.Position(d.pos), types.TypeString(provided, nil), types.TypeString(iface, nil))
}
name := args[0]
if pset := sets[name]; pset != nil {
pset.bindings = append(pset.bindings, ifaceBinding{
iface: iface,
provided: provided,
})
} else {
sets[name] = &providerSet{
bindings: []ifaceBinding{{
iface: iface,
provided: provided,
}},
}
}
case "import":
args := d.args()
if len(args) < 2 {
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
return fmt.Errorf("%v: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
}
name := args[0]
for _, spec := range args[1:] {
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
if ref.importPath != fctx.pkg.Path() {
imported := false
for _, imp := range fctx.pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
if findImport(fctx.pkg, ref.importPath) == nil {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.providerSetRef {
if ref == other.symref {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
imports: []providerSetImport{{symref: ref, pos: d.pos}},
}
}
}
@@ -233,7 +295,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa
}
}
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
if mods, cached := mc.sets[ref.importPath]; cached {
mod := mods[ref.name]
if mod == nil {
@@ -263,46 +325,58 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
return mod, nil
}
// A providerSetRef is a parsed reference to a collection of providers.
type providerSetRef struct {
// A symref is a parsed reference to a symbol (either a provider set or a Go object).
type symref struct {
importPath string
name string
}
func parseProviderSetRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) {
// TODO(light): verify that provider set name is an identifier before returning
i := strings.LastIndexByte(ref, '.')
if i == -1 {
return providerSetRef{importPath: pkg, name: ref}, nil
return symref{importPath: pkg, name: ref}, nil
}
imp, name := ref[:i], ref[i+1:]
if strings.HasPrefix(imp, `"`) {
path, err := strconv.Unquote(imp)
if err != nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref)
}
path, err = r.resolve(pos, path)
if err != nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %v", ref, err)
return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err)
}
return providerSetRef{importPath: path, name: name}, nil
return symref{importPath: path, name: name}, nil
}
_, obj := s.LookupParent(imp, pos)
if obj == nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp)
return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, imp)
}
pn, ok := obj.(*types.PkgName)
if !ok {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp)
return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, imp)
}
return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil
return symref{importPath: pn.Imported().Path(), name: name}, nil
}
func (ref providerSetRef) String() string {
func (ref symref) String() string {
return strconv.Quote(ref.importPath) + "." + ref.name
}
func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) {
imp := findImport(pkg, ref.importPath)
if imp == nil {
return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref)
}
obj := imp.Scope().Lookup(ref.name)
if obj == nil {
return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name)
}
return obj, nil
}
type importResolver struct {
fset *token.FileSet
bctx *build.Context
@@ -333,6 +407,18 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
return pkg.ImportPath, nil
}
func findImport(pkg *types.Package, path string) *types.Package {
if pkg.Path() == path {
return pkg
}
for _, imp := range pkg.Imports() {
if imp.Path() == path {
return imp
}
}
return nil
}
// A directive is a parsed goose comment.
type directive struct {
pos token.Pos

View File

@@ -0,0 +1,26 @@
package main
import (
"fmt"
_ "foo"
)
func main() {
fmt.Println(injectFooer().Foo())
}
type Bar string
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide
func provideBar() *Bar {
b := new(Bar)
*b = "Hello, World!"
return b
}
//goose:bind provideBar "foo".Fooer *Bar

View File

@@ -0,0 +1,9 @@
//+build gooseinject
package main
import "foo"
//goose:use provideBar
func injectFooer() foo.Fooer

View File

@@ -0,0 +1,5 @@
package foo
type Fooer interface {
Foo() string
}

View File

@@ -0,0 +1 @@
Hello, World!

View File

@@ -0,0 +1 @@
bar

View File

@@ -0,0 +1,26 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooer().Foo())
}
type Fooer interface {
Foo() string
}
type Bar string
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide
func provideBar() *Bar {
b := new(Bar)
*b = "Hello, World!"
return b
}
//goose:bind provideBar Fooer *Bar

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use provideBar
func injectFooer() Fooer

View File

@@ -0,0 +1 @@
Hello, World!

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,50 @@
// This test verifies that the concrete type is provided only once, even if an
// interface additionally depends on it.
package main
import (
"fmt"
"sync"
)
func main() {
injectFooBar()
fmt.Println(provideBarCalls)
}
type Fooer interface {
Foo() string
}
type Bar string
type FooBar struct {
Fooer Fooer
Bar *Bar
}
func (b *Bar) Foo() string {
return string(*b)
}
//goose:provide
//goose:bind provideBar Fooer *Bar
func provideBar() *Bar {
mu.Lock()
provideBarCalls++
mu.Unlock()
b := new(Bar)
*b = "Hello, World!"
return b
}
var (
mu sync.Mutex
provideBarCalls int
)
//goose:provide
func provideFooBar(fooer Fooer, bar *Bar) FooBar {
return FooBar{fooer, bar}
}

View File

@@ -0,0 +1,8 @@
//+build gooseinject
package main
//goose:use provideBar
//goose:use provideFooBar
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,22 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooer().Foo())
}
type Fooer interface {
Foo() string
}
type Bar string
func (b Bar) Foo() string {
return string(b)
}
//goose:provide
func provideBar() Bar {
return "Hello, World!"
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use provideBar
func injectFooer() Fooer

View File

@@ -0,0 +1 @@
ERROR

View File

@@ -0,0 +1 @@
foo