Files
wire/internal/goose/analyze.go
Ross Light 1380f96c06 goose: add interface binding
An interface binding instructs goose that a concrete type should be used
to satisfy a dependency on an interface type. goose could determine this
implicitly, but having an explicit directive makes the provider author's
intent clear and allows different concrete types to satisfy different
smaller interfaces.

Reviewed-by: Tuo Shan <shantuo@google.com>
2018-11-12 14:09:56 -08:00

211 lines
6.3 KiB
Go

package goose
import (
"fmt"
"go/token"
"go/types"
"golang.org/x/tools/go/types/typeutil"
)
// call is a single step of an injector function: one invocation of a
// provider function whose result becomes a new local value.
type call struct {
	// importPath and funcName name the provider function to invoke.
	importPath, funcName string

	// args has one entry per provider parameter, interpreted as:
	//   - -1: the parameter receives the zero value of its type;
	//   - 0 <= args[i] < len(given): the corresponding injector input;
	//   - args[i] >= len(given): the result of an earlier call.
	args []int

	// ins holds the parameter types, index-for-index with args.
	ins []types.Type

	// out is the type of the value this call produces.
	out types.Type

	// hasErr reports whether the provider also returns an error.
	hasErr bool
}
// solve determines the provider calls, in dependency order, that an injector
// must make to produce a value of type out. The given types are the
// injector's inputs; every other needed type must come from a provider in
// the provider sets named by sets (or the sets they import).
//
// Values are numbered by slot: slots 0 through len(given)-1 are the inputs
// in order, and slot len(given)+k is the result of the k-th returned call.
//
// An error is returned if two inputs share a type, the provider sets cannot
// be assembled, an input's type is also supplied by a provider, a required
// type has no provider, or the providers form a cycle.
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) {
	for i := 1; i < len(given); i++ {
		for j := 0; j < i; j++ {
			if types.Identical(given[i], given[j]) {
				return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(given[i], nil))
			}
		}
	}
	providers, err := buildProviderMap(mc, sets)
	if err != nil {
		return nil, err
	}

	// slots maps every type whose value is already available to its slot.
	// It is seeded with the inputs, which must not also be providable.
	slots := new(typeutil.Map)
	for i, g := range given {
		if clash := providers.At(g); clash != nil {
			info := clash.(*providerInfo)
			return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), info.funcName, mc.fset.Position(info.pos))
		}
		slots.Set(g, i)
	}

	// The provider graph is ordered by a depth-first, post-order traversal so
	// that each call appears after the calls producing its arguments. path is
	// the chain of types being resolved, root first; meeting a type that is
	// already on path means the graph has a cycle.
	var calls []call
	var resolve func(path []providerInput) error
	resolve = func(path []providerInput) error {
		depth := len(path) - 1
		need := path[depth]
		if slots.At(need.typ) != nil {
			return nil // already available
		}
		for _, pending := range path[:depth] {
			if types.Identical(need.typ, pending.typ) {
				// TODO(light): describe cycle
				return fmt.Errorf("cycle for %s", types.TypeString(need.typ, nil))
			}
		}

		prov, _ := providers.At(need.typ).(*providerInfo)
		switch {
		case prov == nil && need.optional:
			// Left without a slot; consumers are given -1 (zero value).
			return nil
		case prov == nil && depth == 0:
			return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(need.typ, nil))
		case prov == nil:
			// TODO(light): give name of provider
			return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(need.typ, nil), types.TypeString(path[depth-1].typ, nil))
		case !types.Identical(prov.out, need.typ):
			// Interface binding: produce the concrete type, then let the
			// interface share its slot instead of emitting a call of its own.
			if err := resolve(append(path, providerInput{typ: prov.out})); err != nil {
				return err
			}
			slots.Set(need.typ, slots.At(prov.out))
			return nil
		}

		// TODO(light): appending to path will discard grown backing arrays.
		for _, dep := range prov.args {
			if err := resolve(append(path, dep)); err != nil {
				return err
			}
		}
		argSlots := make([]int, len(prov.args))
		argTypes := make([]types.Type, len(prov.args))
		for i, dep := range prov.args {
			argTypes[i] = dep.typ
			argSlots[i] = -1
			if s := slots.At(dep.typ); s != nil {
				argSlots[i] = s.(int)
			}
		}
		slots.Set(need.typ, len(given)+len(calls))
		calls = append(calls, call{
			importPath: prov.importPath,
			funcName:   prov.funcName,
			args:       argSlots,
			ins:        argTypes,
			out:        need.typ,
			hasErr:     prov.hasErr,
		})
		return nil
	}

	if err := resolve([]providerInput{{typ: out}}); err != nil {
		return nil, err
	}
	return calls, nil
}
// buildProviderMap visits the provider sets named by sets, plus every set
// they transitively import, breadth-first. It returns a map from each
// providable type to the *providerInfo that produces it.
//
// Interface bindings are applied after all providers have been collected.
// A binding maps its interface type to the provider of its concrete type, so
// an entry's key may differ from that provider's out type.
//
// An error is returned in three cases:
//   - a provider set cannot be loaded;
//   - two providers or bindings supply the same type (the error points at
//     the earlier provider or binding);
//   - a binding's concrete type has no provider.
func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) {
	// nextEnt is a provider set waiting to be visited.
	type nextEnt struct {
		to   symref
		from symref    // set whose import queued this entry; zero for sets named by the injector
		pos  token.Pos // position of that import; invalid for sets named by the injector
	}
	// binding is an interface binding plus the set that imported the
	// provider set declaring it.
	type binding struct {
		ifaceBinding
		from symref
	}

	pm := new(typeutil.Map)     // to *providerInfo
	srcPos := new(typeutil.Map) // to token.Pos of the provider or binding that added the pm entry

	// multipleBindings builds the error for a type supplied at pos (by a set
	// imported by from) that was already supplied at prevPos.
	multipleBindings := func(pos token.Pos, typ types.Type, from symref, prevPos token.Pos) error {
		at := mc.fset.Position(pos)
		name := types.TypeString(typ, nil)
		prevAt := mc.fset.Position(prevPos)
		if from.importPath == "" {
			// Provider set is imported directly by injector.
			return fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", at, name, prevAt)
		}
		return fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", at, name, from, prevAt)
	}

	var bindings []binding
	visited := make(map[symref]struct{})
	var next []nextEnt // FIFO queue of sets to visit
	for _, ref := range sets {
		next = append(next, nextEnt{to: ref})
	}
	for len(next) > 0 {
		curr := next[0]
		next = next[1:]
		if _, skip := visited[curr.to]; skip {
			continue
		}
		visited[curr.to] = struct{}{}
		pset, err := mc.get(curr.to)
		if err != nil {
			if !curr.pos.IsValid() {
				return nil, err
			}
			return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
		}
		for _, p := range pset.providers {
			if pm.At(p.out) != nil {
				return nil, multipleBindings(p.pos, p.out, curr.from, srcPos.At(p.out).(token.Pos))
			}
			pm.Set(p.out, p)
			srcPos.Set(p.out, p.pos)
		}
		for _, b := range pset.bindings {
			bindings = append(bindings, binding{ifaceBinding: b, from: curr.from})
		}
		for _, imp := range pset.imports {
			next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
		}
	}

	for _, b := range bindings {
		if pm.At(b.iface) != nil {
			// srcPos holds the earlier provider's or binding's own position,
			// so a clash between two bindings points at the first binding.
			return nil, multipleBindings(b.pos, b.iface, b.from, srcPos.At(b.iface).(token.Pos))
		}
		concrete := pm.At(b.provided)
		if concrete == nil {
			pos := mc.fset.Position(b.pos)
			typ := types.TypeString(b.provided, nil)
			if b.from.importPath == "" {
				// The set declaring the binding is imported directly by the injector.
				return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
			}
			return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
		}
		pm.Set(b.iface, concrete)
		srcPos.Set(b.iface, b.pos)
	}
	return pm, nil
}