wire: detect cycles incrementally (google/go-cloud#102)
Idea originally mentioned in google/go-cloud#29. This means that any provider set loaded must not have cycles, which is stricter than before. The cycle error message now gives full detail on what caused the cycle.
This commit is contained in:
@@ -15,10 +15,12 @@
|
||||
package wire
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/types/typeutil"
|
||||
)
|
||||
@@ -106,8 +108,8 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
}
|
||||
|
||||
// Topological sort of the directed graph defined by the providers
|
||||
// using a depth-first search. The graph may contain cycles, which
|
||||
// should trigger an error.
|
||||
// using a depth-first search. Provider set graphs are guaranteed to
|
||||
// be acyclic.
|
||||
var calls []call
|
||||
var visit func(trail []ProviderInput) error
|
||||
visit = func(trail []ProviderInput) error {
|
||||
@@ -115,12 +117,6 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
if index.At(typ) != nil {
|
||||
return nil
|
||||
}
|
||||
for _, in := range trail[:len(trail)-1] {
|
||||
if types.Identical(typ, in.Type) {
|
||||
// TODO(light): Describe cycle.
|
||||
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
||||
}
|
||||
}
|
||||
|
||||
switch pv := set.For(typ); {
|
||||
case pv.IsNil():
|
||||
@@ -248,6 +244,56 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
|
||||
return providerMap, nil
|
||||
}
|
||||
|
||||
func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) error {
|
||||
// We must visit every provider type inside provider map, but we don't
|
||||
// have a well-defined starting point and there may be several
|
||||
// distinct graphs. Thus, we start a depth-first search at every
|
||||
// provider, but keep a shared record of visited providers to avoid
|
||||
// duplicating work.
|
||||
visited := new(typeutil.Map) // to bool
|
||||
visited.SetHasher(hasher)
|
||||
for _, root := range providerMap.Keys() {
|
||||
// Depth-first search using a stack of trails through the provider map.
|
||||
stk := [][]types.Type{{root}}
|
||||
for len(stk) > 0 {
|
||||
curr := stk[len(stk)-1]
|
||||
stk = stk[:len(stk)-1]
|
||||
head := curr[len(curr)-1]
|
||||
if v, _ := visited.At(head).(bool); v {
|
||||
continue
|
||||
}
|
||||
visited.Set(head, true)
|
||||
switch x := providerMap.At(head).(type) {
|
||||
case nil:
|
||||
// Leaf: input.
|
||||
case *Value:
|
||||
// Leaf: values do not have dependencies.
|
||||
case *Provider:
|
||||
for _, arg := range x.Args {
|
||||
a := arg.Type
|
||||
for i, b := range curr {
|
||||
if types.Identical(a, b) {
|
||||
sb := new(strings.Builder)
|
||||
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
|
||||
for j := i; j < len(curr); j++ {
|
||||
p := providerMap.At(curr[j]).(*Provider)
|
||||
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
|
||||
}
|
||||
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
|
||||
return errors.New(sb.String())
|
||||
}
|
||||
}
|
||||
next := append(append([]types.Type(nil), curr...), a)
|
||||
stk = append(stk, next)
|
||||
}
|
||||
default:
|
||||
panic("invalid provider map value")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bindingConflictError creates a new error describing multiple bindings
|
||||
// for the same output type.
|
||||
func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error {
|
||||
|
||||
Reference in New Issue
Block a user