wire: detect cycles incrementally (google/go-cloud#102)

Idea originally mentioned in google/go-cloud#29. This means that any provider set
loaded must not have cycles, which is stricter than before. The cycle
error message now gives full detail on what caused the cycle.
This commit is contained in:
Ross Light
2018-06-20 11:21:59 -07:00
parent cd52d44251
commit 366207371e
6 changed files with 121 additions and 8 deletions

View File

@@ -15,10 +15,12 @@
package wire package wire
import ( import (
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/token" "go/token"
"go/types" "go/types"
"strings"
"golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/go/types/typeutil"
) )
@@ -106,8 +108,8 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
} }
// Topological sort of the directed graph defined by the providers // Topological sort of the directed graph defined by the providers
// using a depth-first search. The graph may contain cycles, which // using a depth-first search. Provider set graphs are guaranteed to
// should trigger an error. // be acyclic.
var calls []call var calls []call
var visit func(trail []ProviderInput) error var visit func(trail []ProviderInput) error
visit = func(trail []ProviderInput) error { visit = func(trail []ProviderInput) error {
@@ -115,12 +117,6 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
if index.At(typ) != nil { if index.At(typ) != nil {
return nil return nil
} }
for _, in := range trail[:len(trail)-1] {
if types.Identical(typ, in.Type) {
// TODO(light): Describe cycle.
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
}
}
switch pv := set.For(typ); { switch pv := set.For(typ); {
case pv.IsNil(): case pv.IsNil():
@@ -248,6 +244,56 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
return providerMap, nil return providerMap, nil
} }
func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) error {
// We must visit every provider type inside provider map, but we don't
// have a well-defined starting point and there may be several
// distinct graphs. Thus, we start a depth-first search at every
// provider, but keep a shared record of visited providers to avoid
// duplicating work.
visited := new(typeutil.Map) // to bool
visited.SetHasher(hasher)
for _, root := range providerMap.Keys() {
// Depth-first search using a stack of trails through the provider map.
stk := [][]types.Type{{root}}
for len(stk) > 0 {
curr := stk[len(stk)-1]
stk = stk[:len(stk)-1]
head := curr[len(curr)-1]
if v, _ := visited.At(head).(bool); v {
continue
}
visited.Set(head, true)
switch x := providerMap.At(head).(type) {
case nil:
// Leaf: input.
case *Value:
// Leaf: values do not have dependencies.
case *Provider:
for _, arg := range x.Args {
a := arg.Type
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
p := providerMap.At(curr[j]).(*Provider)
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
}
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
return errors.New(sb.String())
}
}
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
default:
panic("invalid provider map value")
}
}
}
return nil
}
// bindingConflictError creates a new error describing multiple bindings // bindingConflictError creates a new error describing multiple bindings
// for the same output type. // for the same output type.
func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error { func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error {

View File

@@ -375,6 +375,9 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := verifyAcyclic(pset.providerMap, oc.hasher); err != nil {
return nil, err
}
return pset, nil return pset, nil
} }

37
internal/wire/testdata/Cycle/foo/foo.go vendored Normal file
View File

@@ -0,0 +1,37 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import "fmt"
func main() {
fmt.Println(injectedBaz())
}
type Foo int
type Bar int
type Baz int
func provideFoo(_ Baz) Foo {
return 0
}
func provideBar(_ Foo) Bar {
return 0
}
func provideBaz(_ Bar) Baz {
return 0
}

View File

@@ -0,0 +1,25 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//+build wireinject
package main
import (
"github.com/google/go-cloud/wire"
)
func injectedBaz() Baz {
panic(wire.Build(provideFoo, provideBar, provideBaz))
}

1
internal/wire/testdata/Cycle/out.txt vendored Normal file
View File

@@ -0,0 +1 @@
ERROR

1
internal/wire/testdata/Cycle/pkg vendored Normal file
View File

@@ -0,0 +1 @@
foo