2018-04-02 09:21:52 -07:00
|
|
|
package goose
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"go/token"
|
|
|
|
|
"go/types"
|
|
|
|
|
|
|
|
|
|
"golang.org/x/tools/go/types/typeutil"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// A call represents a step of an injector function: one invocation of a
// provider function whose result is stored in a new local variable.
type call struct {
	// importPath and funcName identify the provider function to call.
	importPath, funcName string

	// args holds one entry per argument passed to the provider. Each is:
	//   a) one of the givens (args[i] < len(given)),
	//   b) the result of a previous provider call (args[i] >= len(given)), or
	//   c) the zero value for the type (args[i] == -1).
	args []int

	// ins holds the type of each argument, index-aligned with args.
	ins []types.Type

	// out is the type produced by this provider call.
	out types.Type

	// hasCleanup is true if the provider call returns a cleanup function;
	// hasErr is true if the provider call returns an error.
	hasCleanup, hasErr bool
}
|
|
|
|
|
|
|
|
|
|
// solve finds the sequence of calls required to produce an output type
|
|
|
|
|
// with an optional set of provided inputs.
|
2018-04-02 14:08:17 -07:00
|
|
|
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) {
|
2018-04-02 09:21:52 -07:00
|
|
|
for i, g := range given {
|
|
|
|
|
for _, h := range given[:i] {
|
|
|
|
|
if types.Identical(g, h) {
|
|
|
|
|
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
providers, err := buildProviderMap(mc, sets)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Start building the mapping of type to local variable of the given type.
|
|
|
|
|
// The first len(given) local variables are the given types.
|
|
|
|
|
index := new(typeutil.Map)
|
|
|
|
|
for i, g := range given {
|
|
|
|
|
if p := providers.At(g); p != nil {
|
|
|
|
|
pp := p.(*providerInfo)
|
|
|
|
|
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
|
|
|
|
|
}
|
|
|
|
|
index.Set(g, i)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Topological sort of the directed graph defined by the providers
|
|
|
|
|
// using a depth-first search. The graph may contain cycles, which
|
|
|
|
|
// should trigger an error.
|
|
|
|
|
var calls []call
|
2018-03-30 21:34:08 -07:00
|
|
|
var visit func(trail []providerInput) error
|
|
|
|
|
visit = func(trail []providerInput) error {
|
|
|
|
|
typ := trail[len(trail)-1].typ
|
2018-04-02 09:21:52 -07:00
|
|
|
if index.At(typ) != nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2018-03-30 21:34:08 -07:00
|
|
|
for _, in := range trail[:len(trail)-1] {
|
|
|
|
|
if types.Identical(typ, in.typ) {
|
2018-04-02 09:21:52 -07:00
|
|
|
// TODO(light): describe cycle
|
|
|
|
|
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
p, _ := providers.At(typ).(*providerInfo)
|
|
|
|
|
if p == nil {
|
2018-03-30 21:34:08 -07:00
|
|
|
if trail[len(trail)-1].optional {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2018-04-02 09:21:52 -07:00
|
|
|
if len(trail) == 1 {
|
|
|
|
|
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
|
|
|
|
}
|
|
|
|
|
// TODO(light): give name of provider
|
2018-03-30 21:34:08 -07:00
|
|
|
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-02 14:08:17 -07:00
|
|
|
if !types.Identical(p.out, typ) {
|
|
|
|
|
// Interface binding. Don't create a call ourselves.
|
|
|
|
|
if err := visit(append(trail, providerInput{typ: p.out})); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
index.Set(typ, index.At(p.out))
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2018-04-02 09:21:52 -07:00
|
|
|
for _, a := range p.args {
|
|
|
|
|
// TODO(light): this will discard grown trail arrays.
|
|
|
|
|
if err := visit(append(trail, a)); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
args := make([]int, len(p.args))
|
2018-03-30 21:34:08 -07:00
|
|
|
ins := make([]types.Type, len(p.args))
|
2018-04-02 09:21:52 -07:00
|
|
|
for i := range p.args {
|
2018-03-30 21:34:08 -07:00
|
|
|
ins[i] = p.args[i].typ
|
|
|
|
|
if x := index.At(p.args[i].typ); x != nil {
|
|
|
|
|
args[i] = x.(int)
|
|
|
|
|
} else {
|
|
|
|
|
args[i] = -1
|
|
|
|
|
}
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
|
|
|
|
index.Set(typ, len(given)+len(calls))
|
|
|
|
|
calls = append(calls, call{
|
|
|
|
|
importPath: p.importPath,
|
|
|
|
|
funcName: p.funcName,
|
|
|
|
|
args: args,
|
2018-03-30 21:34:08 -07:00
|
|
|
ins: ins,
|
2018-04-02 09:21:52 -07:00
|
|
|
out: typ,
|
2018-04-03 13:13:15 -07:00
|
|
|
hasCleanup: p.hasCleanup,
|
2018-04-02 09:21:52 -07:00
|
|
|
hasErr: p.hasErr,
|
|
|
|
|
})
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2018-03-30 21:34:08 -07:00
|
|
|
if err := visit([]providerInput{{typ: out}}); err != nil {
|
2018-04-02 09:21:52 -07:00
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return calls, nil
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-02 14:08:17 -07:00
|
|
|
func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) {
|
2018-04-02 09:21:52 -07:00
|
|
|
type nextEnt struct {
|
2018-04-02 14:08:17 -07:00
|
|
|
to symref
|
2018-04-02 09:21:52 -07:00
|
|
|
|
2018-04-02 14:08:17 -07:00
|
|
|
from symref
|
2018-04-02 09:21:52 -07:00
|
|
|
pos token.Pos
|
|
|
|
|
}
|
2018-04-02 14:08:17 -07:00
|
|
|
type binding struct {
|
|
|
|
|
ifaceBinding
|
|
|
|
|
pset symref
|
|
|
|
|
from symref
|
|
|
|
|
}
|
2018-04-02 09:21:52 -07:00
|
|
|
|
|
|
|
|
pm := new(typeutil.Map) // to *providerInfo
|
2018-04-02 14:08:17 -07:00
|
|
|
var bindings []binding
|
|
|
|
|
visited := make(map[symref]struct{})
|
2018-04-02 09:21:52 -07:00
|
|
|
var next []nextEnt
|
|
|
|
|
for _, ref := range sets {
|
|
|
|
|
next = append(next, nextEnt{to: ref})
|
|
|
|
|
}
|
|
|
|
|
for len(next) > 0 {
|
|
|
|
|
curr := next[0]
|
|
|
|
|
copy(next, next[1:])
|
|
|
|
|
next = next[:len(next)-1]
|
|
|
|
|
if _, skip := visited[curr.to]; skip {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
visited[curr.to] = struct{}{}
|
2018-04-02 14:08:17 -07:00
|
|
|
pset, err := mc.get(curr.to)
|
2018-04-02 09:21:52 -07:00
|
|
|
if err != nil {
|
|
|
|
|
if !curr.pos.IsValid() {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
|
|
|
|
|
}
|
2018-04-02 14:08:17 -07:00
|
|
|
for _, p := range pset.providers {
|
2018-04-02 09:21:52 -07:00
|
|
|
if prev := pm.At(p.out); prev != nil {
|
|
|
|
|
pos := mc.fset.Position(p.pos)
|
|
|
|
|
typ := types.TypeString(p.out, nil)
|
|
|
|
|
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
2018-04-02 14:08:17 -07:00
|
|
|
if curr.from.importPath == "" {
|
|
|
|
|
// Provider set is imported directly by injector.
|
2018-04-02 09:21:52 -07:00
|
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
|
|
|
|
}
|
|
|
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
|
|
|
|
|
}
|
|
|
|
|
pm.Set(p.out, p)
|
|
|
|
|
}
|
2018-04-02 14:08:17 -07:00
|
|
|
for _, b := range pset.bindings {
|
|
|
|
|
bindings = append(bindings, binding{
|
|
|
|
|
ifaceBinding: b,
|
|
|
|
|
pset: curr.to,
|
|
|
|
|
from: curr.from,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
for _, imp := range pset.imports {
|
|
|
|
|
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for _, b := range bindings {
|
|
|
|
|
if prev := pm.At(b.iface); prev != nil {
|
|
|
|
|
pos := mc.fset.Position(b.pos)
|
|
|
|
|
typ := types.TypeString(b.iface, nil)
|
|
|
|
|
// TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding.
|
|
|
|
|
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
|
|
|
|
if b.from.importPath == "" {
|
|
|
|
|
// Provider set is imported directly by injector.
|
|
|
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
|
|
|
|
}
|
|
|
|
|
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
|
|
|
|
|
}
|
|
|
|
|
concrete := pm.At(b.provided)
|
|
|
|
|
if concrete == nil {
|
|
|
|
|
pos := mc.fset.Position(b.pos)
|
|
|
|
|
typ := types.TypeString(b.provided, nil)
|
|
|
|
|
if b.from.importPath == "" {
|
|
|
|
|
// Concrete provider is imported directly by injector.
|
|
|
|
|
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
|
|
|
|
|
}
|
|
|
|
|
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-02 14:08:17 -07:00
|
|
|
pm.Set(b.iface, concrete)
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
|
|
|
|
return pm, nil
|
|
|
|
|
}
|