2018-04-02 09:21:52 -07:00
|
|
|
package goose
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"go/token"
|
|
|
|
|
"go/types"
|
|
|
|
|
|
|
|
|
|
"golang.org/x/tools/go/types/typeutil"
|
|
|
|
|
)
|
|
|
|
|
|
2018-04-03 21:11:53 -07:00
|
|
|
// call describes one step in the body of an injector function. The step
// is emitted either as a call to a provider function or, when isStruct
// is set, as a composite struct literal.
type call struct {
	// importPath and name together identify the provider used by this step.
	importPath string
	name       string

	// args holds one entry per argument handed to the provider. An entry
	// selects where the argument's value comes from:
	//   - a value below len(given) refers to the given with that index,
	//   - a value of len(given) or more refers to the result of an
	//     earlier provider call, and
	//   - -1 stands for the zero value of the argument's type.
	args []int

	// isStruct reports whether this step should be generated as a struct
	// composite literal rather than as a function call.
	isStruct bool

	// fieldNames gives the struct field name for each argument.
	// It is only populated when isStruct is true.
	fieldNames []string

	// ins lists the types this step takes as arguments.
	ins []types.Type
	// out is the type this step produces.
	out types.Type
	// hasCleanup reports whether the provider also returns a cleanup function.
	hasCleanup bool
	// hasErr reports whether the provider also returns an error.
	hasErr bool
}
|
|
|
|
|
|
|
|
|
|
// solve finds the sequence of calls required to produce an output type
|
|
|
|
|
// with an optional set of provided inputs.
|
2018-04-27 13:44:54 -04:00
|
|
|
func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, error) {
|
2018-04-02 09:21:52 -07:00
|
|
|
for i, g := range given {
|
|
|
|
|
for _, h := range given[:i] {
|
|
|
|
|
if types.Identical(g, h) {
|
|
|
|
|
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
providers, err := buildProviderMap(fset, set)
|
2018-04-02 09:21:52 -07:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Start building the mapping of type to local variable of the given type.
|
|
|
|
|
// The first len(given) local variables are the given types.
|
|
|
|
|
index := new(typeutil.Map)
|
|
|
|
|
for i, g := range given {
|
|
|
|
|
if p := providers.At(g); p != nil {
|
2018-04-04 14:42:56 -07:00
|
|
|
pp := p.(*Provider)
|
2018-04-27 13:44:54 -04:00
|
|
|
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, fset.Position(pp.Pos))
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
|
|
|
|
index.Set(g, i)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Topological sort of the directed graph defined by the providers
|
|
|
|
|
// using a depth-first search. The graph may contain cycles, which
|
|
|
|
|
// should trigger an error.
|
|
|
|
|
var calls []call
|
2018-04-04 14:42:56 -07:00
|
|
|
var visit func(trail []ProviderInput) error
|
|
|
|
|
visit = func(trail []ProviderInput) error {
|
|
|
|
|
typ := trail[len(trail)-1].Type
|
2018-04-02 09:21:52 -07:00
|
|
|
if index.At(typ) != nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2018-03-30 21:34:08 -07:00
|
|
|
for _, in := range trail[:len(trail)-1] {
|
2018-04-04 14:42:56 -07:00
|
|
|
if types.Identical(typ, in.Type) {
|
2018-04-27 17:40:40 -04:00
|
|
|
// TODO(light): Describe cycle.
|
2018-04-02 09:21:52 -07:00
|
|
|
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-04 14:42:56 -07:00
|
|
|
p, _ := providers.At(typ).(*Provider)
|
2018-04-02 09:21:52 -07:00
|
|
|
if p == nil {
|
|
|
|
|
if len(trail) == 1 {
|
|
|
|
|
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
|
|
|
|
}
|
2018-04-27 17:40:40 -04:00
|
|
|
// TODO(light): Give name of provider.
|
2018-04-04 14:42:56 -07:00
|
|
|
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
if !types.Identical(p.Out, typ) {
|
2018-04-02 14:08:17 -07:00
|
|
|
// Interface binding. Don't create a call ourselves.
|
2018-04-04 14:42:56 -07:00
|
|
|
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
|
2018-04-02 14:08:17 -07:00
|
|
|
return err
|
|
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
index.Set(typ, index.At(p.Out))
|
2018-04-02 14:08:17 -07:00
|
|
|
return nil
|
|
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
for _, a := range p.Args {
|
2018-04-27 17:40:40 -04:00
|
|
|
// TODO(light): This will discard grown trail arrays.
|
2018-04-02 09:21:52 -07:00
|
|
|
if err := visit(append(trail, a)); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
args := make([]int, len(p.Args))
|
|
|
|
|
ins := make([]types.Type, len(p.Args))
|
|
|
|
|
for i := range p.Args {
|
|
|
|
|
ins[i] = p.Args[i].Type
|
|
|
|
|
if x := index.At(p.Args[i].Type); x != nil {
|
2018-03-30 21:34:08 -07:00
|
|
|
args[i] = x.(int)
|
|
|
|
|
} else {
|
|
|
|
|
args[i] = -1
|
|
|
|
|
}
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
|
|
|
|
index.Set(typ, len(given)+len(calls))
|
|
|
|
|
calls = append(calls, call{
|
2018-04-04 14:42:56 -07:00
|
|
|
importPath: p.ImportPath,
|
|
|
|
|
name: p.Name,
|
2018-04-02 09:21:52 -07:00
|
|
|
args: args,
|
2018-04-04 14:42:56 -07:00
|
|
|
isStruct: p.IsStruct,
|
|
|
|
|
fieldNames: p.Fields,
|
2018-03-30 21:34:08 -07:00
|
|
|
ins: ins,
|
2018-04-02 09:21:52 -07:00
|
|
|
out: typ,
|
2018-04-04 14:42:56 -07:00
|
|
|
hasCleanup: p.HasCleanup,
|
|
|
|
|
hasErr: p.HasErr,
|
2018-04-02 09:21:52 -07:00
|
|
|
})
|
|
|
|
|
return nil
|
|
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
if err := visit([]ProviderInput{{Type: out}}); err != nil {
|
2018-04-02 09:21:52 -07:00
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return calls, nil
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-27 13:44:54 -04:00
|
|
|
func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, error) {
|
2018-04-02 14:08:17 -07:00
|
|
|
type binding struct {
|
2018-04-27 13:44:54 -04:00
|
|
|
*IfaceBinding
|
|
|
|
|
set *ProviderSet
|
2018-04-02 14:08:17 -07:00
|
|
|
}
|
2018-04-02 09:21:52 -07:00
|
|
|
|
2018-04-27 13:44:54 -04:00
|
|
|
providerMap := new(typeutil.Map) // to *Provider
|
|
|
|
|
setMap := new(typeutil.Map) // to *ProviderSet, for error messages
|
2018-04-02 14:08:17 -07:00
|
|
|
var bindings []binding
|
2018-04-27 13:44:54 -04:00
|
|
|
visited := make(map[*ProviderSet]struct{})
|
|
|
|
|
next := []*ProviderSet{set}
|
2018-04-02 09:21:52 -07:00
|
|
|
for len(next) > 0 {
|
|
|
|
|
curr := next[0]
|
|
|
|
|
copy(next, next[1:])
|
|
|
|
|
next = next[:len(next)-1]
|
2018-04-27 13:44:54 -04:00
|
|
|
if _, skip := visited[curr]; skip {
|
2018-04-02 09:21:52 -07:00
|
|
|
continue
|
|
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
visited[curr] = struct{}{}
|
|
|
|
|
for _, p := range curr.Providers {
|
|
|
|
|
if providerMap.At(p.Out) != nil {
|
|
|
|
|
return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
providerMap.Set(p.Out, p)
|
|
|
|
|
setMap.Set(p.Out, curr)
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
for _, b := range curr.Bindings {
|
2018-04-02 14:08:17 -07:00
|
|
|
bindings = append(bindings, binding{
|
2018-04-04 14:42:56 -07:00
|
|
|
IfaceBinding: b,
|
2018-04-27 13:44:54 -04:00
|
|
|
set: curr,
|
2018-04-02 14:08:17 -07:00
|
|
|
})
|
|
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
for _, imp := range curr.Imports {
|
|
|
|
|
next = append(next, imp)
|
2018-04-02 14:08:17 -07:00
|
|
|
}
|
|
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
// Validate that bindings have their concrete type provided in the set.
|
|
|
|
|
// TODO(light): Move this validation up into provider set creation.
|
2018-04-02 14:08:17 -07:00
|
|
|
for _, b := range bindings {
|
2018-04-27 13:44:54 -04:00
|
|
|
if providerMap.At(b.Iface) != nil {
|
|
|
|
|
return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))
|
2018-04-02 14:08:17 -07:00
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
concrete := providerMap.At(b.Provided)
|
2018-04-02 14:08:17 -07:00
|
|
|
if concrete == nil {
|
2018-04-27 13:44:54 -04:00
|
|
|
pos := fset.Position(b.Pos)
|
2018-04-04 14:42:56 -07:00
|
|
|
typ := types.TypeString(b.Provided, nil)
|
2018-04-27 13:44:54 -04:00
|
|
|
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
providerMap.Set(b.Iface, concrete)
|
|
|
|
|
setMap.Set(b.Iface, b.set)
|
|
|
|
|
}
|
|
|
|
|
return providerMap, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// bindingConflictError creates a new error describing multiple bindings
|
|
|
|
|
// for the same output type.
|
|
|
|
|
func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error {
|
|
|
|
|
position := fset.Position(pos)
|
|
|
|
|
typString := types.TypeString(typ, nil)
|
|
|
|
|
if prevSet.Name == "" {
|
|
|
|
|
prevPosition := fset.Position(prevSet.Pos)
|
|
|
|
|
return fmt.Errorf("%v: multiple bindings for %s (previous binding at %v)",
|
|
|
|
|
position, typString, prevPosition)
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|
2018-04-27 13:44:54 -04:00
|
|
|
return fmt.Errorf("%v: multiple bindings for %s (previous binding in %q.%s)",
|
|
|
|
|
position, typString, prevSet.PkgPath, prevSet.Name)
|
2018-04-02 09:21:52 -07:00
|
|
|
}
|