wire: detect cycles incrementally (google/go-cloud#102)
Idea originally mentioned in google/go-cloud#29. This means that any provider set loaded must not have cycles, which is stricter than before. The cycle error message now gives full detail on what caused the cycle.
This commit is contained in:
@@ -15,10 +15,12 @@
|
|||||||
package wire
|
package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"go/token"
|
"go/token"
|
||||||
"go/types"
|
"go/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/tools/go/types/typeutil"
|
"golang.org/x/tools/go/types/typeutil"
|
||||||
)
|
)
|
||||||
@@ -106,8 +108,8 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Topological sort of the directed graph defined by the providers
|
// Topological sort of the directed graph defined by the providers
|
||||||
// using a depth-first search. The graph may contain cycles, which
|
// using a depth-first search. Provider set graphs are guaranteed to
|
||||||
// should trigger an error.
|
// be acyclic.
|
||||||
var calls []call
|
var calls []call
|
||||||
var visit func(trail []ProviderInput) error
|
var visit func(trail []ProviderInput) error
|
||||||
visit = func(trail []ProviderInput) error {
|
visit = func(trail []ProviderInput) error {
|
||||||
@@ -115,12 +117,6 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
|||||||
if index.At(typ) != nil {
|
if index.At(typ) != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, in := range trail[:len(trail)-1] {
|
|
||||||
if types.Identical(typ, in.Type) {
|
|
||||||
// TODO(light): Describe cycle.
|
|
||||||
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch pv := set.For(typ); {
|
switch pv := set.For(typ); {
|
||||||
case pv.IsNil():
|
case pv.IsNil():
|
||||||
@@ -248,6 +244,56 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
|
|||||||
return providerMap, nil
|
return providerMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) error {
|
||||||
|
// We must visit every provider type inside provider map, but we don't
|
||||||
|
// have a well-defined starting point and there may be several
|
||||||
|
// distinct graphs. Thus, we start a depth-first search at every
|
||||||
|
// provider, but keep a shared record of visited providers to avoid
|
||||||
|
// duplicating work.
|
||||||
|
visited := new(typeutil.Map) // to bool
|
||||||
|
visited.SetHasher(hasher)
|
||||||
|
for _, root := range providerMap.Keys() {
|
||||||
|
// Depth-first search using a stack of trails through the provider map.
|
||||||
|
stk := [][]types.Type{{root}}
|
||||||
|
for len(stk) > 0 {
|
||||||
|
curr := stk[len(stk)-1]
|
||||||
|
stk = stk[:len(stk)-1]
|
||||||
|
head := curr[len(curr)-1]
|
||||||
|
if v, _ := visited.At(head).(bool); v {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visited.Set(head, true)
|
||||||
|
switch x := providerMap.At(head).(type) {
|
||||||
|
case nil:
|
||||||
|
// Leaf: input.
|
||||||
|
case *Value:
|
||||||
|
// Leaf: values do not have dependencies.
|
||||||
|
case *Provider:
|
||||||
|
for _, arg := range x.Args {
|
||||||
|
a := arg.Type
|
||||||
|
for i, b := range curr {
|
||||||
|
if types.Identical(a, b) {
|
||||||
|
sb := new(strings.Builder)
|
||||||
|
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
|
||||||
|
for j := i; j < len(curr); j++ {
|
||||||
|
p := providerMap.At(curr[j]).(*Provider)
|
||||||
|
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
|
||||||
|
return errors.New(sb.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next := append(append([]types.Type(nil), curr...), a)
|
||||||
|
stk = append(stk, next)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("invalid provider map value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// bindingConflictError creates a new error describing multiple bindings
|
// bindingConflictError creates a new error describing multiple bindings
|
||||||
// for the same output type.
|
// for the same output type.
|
||||||
func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error {
|
func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error {
|
||||||
|
|||||||
@@ -375,6 +375,9 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := verifyAcyclic(pset.providerMap, oc.hasher); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return pset, nil
|
return pset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
37
internal/wire/testdata/Cycle/foo/foo.go
vendored
Normal file
37
internal/wire/testdata/Cycle/foo/foo.go
vendored
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectedBaz())
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type Bar int
|
||||||
|
type Baz int
|
||||||
|
|
||||||
|
func provideFoo(_ Baz) Foo {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func provideBar(_ Foo) Bar {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func provideBaz(_ Bar) Baz {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
25
internal/wire/testdata/Cycle/foo/wire.go
vendored
Normal file
25
internal/wire/testdata/Cycle/foo/wire.go
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//+build wireinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/go-cloud/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
func injectedBaz() Baz {
|
||||||
|
panic(wire.Build(provideFoo, provideBar, provideBaz))
|
||||||
|
}
|
||||||
1
internal/wire/testdata/Cycle/out.txt
vendored
Normal file
1
internal/wire/testdata/Cycle/out.txt
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ERROR
|
||||||
1
internal/wire/testdata/Cycle/pkg
vendored
Normal file
1
internal/wire/testdata/Cycle/pkg
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
foo
|
||||||
Reference in New Issue
Block a user