From 366207371e542d688e8336e4c68e0a4117bc1f2c Mon Sep 17 00:00:00 2001 From: Ross Light Date: Wed, 20 Jun 2018 11:21:59 -0700 Subject: [PATCH] wire: detect cycles incrementally (google/go-cloud#102) Idea originally mentioned in google/go-cloud#29. This means that any provider set loaded must not have cycles, which is stricter than before. The cycle error message now gives full detail on what caused the cycle. --- internal/wire/analyze.go | 62 +++++++++++++++++++++--- internal/wire/parse.go | 3 ++ internal/wire/testdata/Cycle/foo/foo.go | 37 ++++++++++++++ internal/wire/testdata/Cycle/foo/wire.go | 25 ++++++++++ internal/wire/testdata/Cycle/out.txt | 1 + internal/wire/testdata/Cycle/pkg | 1 + 6 files changed, 121 insertions(+), 8 deletions(-) create mode 100644 internal/wire/testdata/Cycle/foo/foo.go create mode 100644 internal/wire/testdata/Cycle/foo/wire.go create mode 100644 internal/wire/testdata/Cycle/out.txt create mode 100644 internal/wire/testdata/Cycle/pkg diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 39c3de8..a8e5690 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -15,10 +15,12 @@ package wire import ( + "errors" "fmt" "go/ast" "go/token" "go/types" + "strings" "golang.org/x/tools/go/types/typeutil" ) @@ -106,8 +108,8 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide } // Topological sort of the directed graph defined by the providers - // using a depth-first search. The graph may contain cycles, which - // should trigger an error. + // using a depth-first search. Provider set graphs are guaranteed to + // be acyclic. var calls []call var visit func(trail []ProviderInput) error visit = func(trail []ProviderInput) error { @@ -115,12 +117,6 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide if index.At(typ) != nil { return nil } - for _, in := range trail[:len(trail)-1] { - if types.Identical(typ, in.Type) { - // TODO(light): Describe cycle. - return fmt.Errorf("cycle for %s", types.TypeString(typ, nil)) - } - } switch pv := set.For(typ); { case pv.IsNil(): @@ -248,6 +244,56 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider return providerMap, nil } +func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) error { + // We must visit every provider type inside provider map, but we don't + // have a well-defined starting point and there may be several + // distinct graphs. Thus, we start a depth-first search at every + // provider, but keep a shared record of visited providers to avoid + // duplicating work. + visited := new(typeutil.Map) // to bool + visited.SetHasher(hasher) + for _, root := range providerMap.Keys() { + // Depth-first search using a stack of trails through the provider map. + stk := [][]types.Type{{root}} + for len(stk) > 0 { + curr := stk[len(stk)-1] + stk = stk[:len(stk)-1] + head := curr[len(curr)-1] + if v, _ := visited.At(head).(bool); v { + continue + } + visited.Set(head, true) + switch x := providerMap.At(head).(type) { + case nil: + // Leaf: input. + case *Value: + // Leaf: values do not have dependencies. + case *Provider: + for _, arg := range x.Args { + a := arg.Type + for i, b := range curr { + if types.Identical(a, b) { + sb := new(strings.Builder) + fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil)) + for j := i; j < len(curr); j++ { + p := providerMap.At(curr[j]).(*Provider) + fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name) + } + fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil)) + return errors.New(sb.String()) + } + } + next := append(append([]types.Type(nil), curr...), a) + stk = append(stk, next) + } + default: + panic("invalid provider map value") + } + } + } + return nil +} + // bindingConflictError creates a new error describing multiple bindings // for the same output type. func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error { diff --git a/internal/wire/parse.go b/internal/wire/parse.go index d05fe83..4428c59 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -375,6 +375,9 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr if err != nil { return nil, err } + if err := verifyAcyclic(pset.providerMap, oc.hasher); err != nil { + return nil, err + } return pset, nil } diff --git a/internal/wire/testdata/Cycle/foo/foo.go b/internal/wire/testdata/Cycle/foo/foo.go new file mode 100644 index 0000000..39da2fa --- /dev/null +++ b/internal/wire/testdata/Cycle/foo/foo.go @@ -0,0 +1,37 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "fmt" + +func main() { + fmt.Println(injectedBaz()) +} + +type Foo int +type Bar int +type Baz int + +func provideFoo(_ Baz) Foo { + return 0 +} + +func provideBar(_ Foo) Bar { + return 0 +} + +func provideBaz(_ Bar) Baz { + return 0 +} diff --git a/internal/wire/testdata/Cycle/foo/wire.go b/internal/wire/testdata/Cycle/foo/wire.go new file mode 100644 index 0000000..9aff17e --- /dev/null +++ b/internal/wire/testdata/Cycle/foo/wire.go @@ -0,0 +1,25 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package main + +import ( + "github.com/google/go-cloud/wire" +) + +func injectedBaz() Baz { + panic(wire.Build(provideFoo, provideBar, provideBaz)) +} diff --git a/internal/wire/testdata/Cycle/out.txt b/internal/wire/testdata/Cycle/out.txt new file mode 100644 index 0000000..5df7507 --- /dev/null +++ b/internal/wire/testdata/Cycle/out.txt @@ -0,0 +1 @@ +ERROR diff --git a/internal/wire/testdata/Cycle/pkg b/internal/wire/testdata/Cycle/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/wire/testdata/Cycle/pkg @@ -0,0 +1 @@ +foo