An interface binding instructs goose that a concrete type should be used to satisfy a dependency on an interface type. goose could determine this implicitly, but having an explicit directive makes the provider author's intent clear and allows different concrete types to satisfy different smaller interfaces. Reviewed-by: Tuo Shan <shantuo@google.com>
211 lines
6.3 KiB
Go
211 lines
6.3 KiB
Go
package goose
|
|
|
|
import (
|
|
"fmt"
|
|
"go/token"
|
|
"go/types"
|
|
|
|
"golang.org/x/tools/go/types/typeutil"
|
|
)
|
|
|
|
// A call represents a step of an injector function: one invocation of a
// provider function whose result becomes a new local variable.
type call struct {
	// importPath and funcName identify the provider function to call.
	importPath, funcName string

	// args has one entry per provider argument. Each entry refers to
	//   - one of the givens passed to solve (args[i] < len(given)),
	//   - the result of an earlier provider call (args[i] >= len(given)), or
	//   - nothing, meaning the zero value for the type is passed (args[i] == -1).
	args []int

	// ins lists the types this call receives as arguments, parallel to args.
	ins []types.Type

	// out is the type produced by this provider call.
	out types.Type

	// hasErr reports whether the provider call also returns an error.
	hasErr bool
}
|
|
|
|
// solve finds the sequence of calls required to produce an output type
|
|
// with an optional set of provided inputs.
|
|
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) {
|
|
for i, g := range given {
|
|
for _, h := range given[:i] {
|
|
if types.Identical(g, h) {
|
|
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
|
|
}
|
|
}
|
|
}
|
|
providers, err := buildProviderMap(mc, sets)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Start building the mapping of type to local variable of the given type.
|
|
// The first len(given) local variables are the given types.
|
|
index := new(typeutil.Map)
|
|
for i, g := range given {
|
|
if p := providers.At(g); p != nil {
|
|
pp := p.(*providerInfo)
|
|
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
|
|
}
|
|
index.Set(g, i)
|
|
}
|
|
|
|
// Topological sort of the directed graph defined by the providers
|
|
// using a depth-first search. The graph may contain cycles, which
|
|
// should trigger an error.
|
|
var calls []call
|
|
var visit func(trail []providerInput) error
|
|
visit = func(trail []providerInput) error {
|
|
typ := trail[len(trail)-1].typ
|
|
if index.At(typ) != nil {
|
|
return nil
|
|
}
|
|
for _, in := range trail[:len(trail)-1] {
|
|
if types.Identical(typ, in.typ) {
|
|
// TODO(light): describe cycle
|
|
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
|
}
|
|
}
|
|
|
|
p, _ := providers.At(typ).(*providerInfo)
|
|
if p == nil {
|
|
if trail[len(trail)-1].optional {
|
|
return nil
|
|
}
|
|
if len(trail) == 1 {
|
|
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
|
}
|
|
// TODO(light): give name of provider
|
|
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
|
|
}
|
|
if !types.Identical(p.out, typ) {
|
|
// Interface binding. Don't create a call ourselves.
|
|
if err := visit(append(trail, providerInput{typ: p.out})); err != nil {
|
|
return err
|
|
}
|
|
index.Set(typ, index.At(p.out))
|
|
return nil
|
|
}
|
|
for _, a := range p.args {
|
|
// TODO(light): this will discard grown trail arrays.
|
|
if err := visit(append(trail, a)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
args := make([]int, len(p.args))
|
|
ins := make([]types.Type, len(p.args))
|
|
for i := range p.args {
|
|
ins[i] = p.args[i].typ
|
|
if x := index.At(p.args[i].typ); x != nil {
|
|
args[i] = x.(int)
|
|
} else {
|
|
args[i] = -1
|
|
}
|
|
}
|
|
index.Set(typ, len(given)+len(calls))
|
|
calls = append(calls, call{
|
|
importPath: p.importPath,
|
|
funcName: p.funcName,
|
|
args: args,
|
|
ins: ins,
|
|
out: typ,
|
|
hasErr: p.hasErr,
|
|
})
|
|
return nil
|
|
}
|
|
if err := visit([]providerInput{{typ: out}}); err != nil {
|
|
return nil, err
|
|
}
|
|
return calls, nil
|
|
}
|
|
|
|
func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) {
|
|
type nextEnt struct {
|
|
to symref
|
|
|
|
from symref
|
|
pos token.Pos
|
|
}
|
|
type binding struct {
|
|
ifaceBinding
|
|
pset symref
|
|
from symref
|
|
}
|
|
|
|
pm := new(typeutil.Map) // to *providerInfo
|
|
var bindings []binding
|
|
visited := make(map[symref]struct{})
|
|
var next []nextEnt
|
|
for _, ref := range sets {
|
|
next = append(next, nextEnt{to: ref})
|
|
}
|
|
for len(next) > 0 {
|
|
curr := next[0]
|
|
copy(next, next[1:])
|
|
next = next[:len(next)-1]
|
|
if _, skip := visited[curr.to]; skip {
|
|
continue
|
|
}
|
|
visited[curr.to] = struct{}{}
|
|
pset, err := mc.get(curr.to)
|
|
if err != nil {
|
|
if !curr.pos.IsValid() {
|
|
return nil, err
|
|
}
|
|
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
|
|
}
|
|
for _, p := range pset.providers {
|
|
if prev := pm.At(p.out); prev != nil {
|
|
pos := mc.fset.Position(p.pos)
|
|
typ := types.TypeString(p.out, nil)
|
|
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
|
if curr.from.importPath == "" {
|
|
// Provider set is imported directly by injector.
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
|
}
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
|
|
}
|
|
pm.Set(p.out, p)
|
|
}
|
|
for _, b := range pset.bindings {
|
|
bindings = append(bindings, binding{
|
|
ifaceBinding: b,
|
|
pset: curr.to,
|
|
from: curr.from,
|
|
})
|
|
}
|
|
for _, imp := range pset.imports {
|
|
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
|
|
}
|
|
}
|
|
for _, b := range bindings {
|
|
if prev := pm.At(b.iface); prev != nil {
|
|
pos := mc.fset.Position(b.pos)
|
|
typ := types.TypeString(b.iface, nil)
|
|
// TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding.
|
|
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
|
if b.from.importPath == "" {
|
|
// Provider set is imported directly by injector.
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
|
}
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
|
|
}
|
|
concrete := pm.At(b.provided)
|
|
if concrete == nil {
|
|
pos := mc.fset.Position(b.pos)
|
|
typ := types.TypeString(b.provided, nil)
|
|
if b.from.importPath == "" {
|
|
// Concrete provider is imported directly by injector.
|
|
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
|
|
}
|
|
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
|
|
}
|
|
pm.Set(b.iface, concrete)
|
|
}
|
|
return pm, nil
|
|
}
|