diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index fd9a4ec..b68afac 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -80,10 +80,11 @@ type call struct { // solve finds the sequence of calls required to produce an output type // with an optional set of provided inputs. func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) { + ec := new(errorCollector) for i, g := range given { for _, h := range given[:i] { if types.Identical(g, h) { - return nil, []error{fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))} + ec.add(fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))) } } } @@ -95,27 +96,35 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide if pv := set.For(g); !pv.IsNil() { switch { case pv.IsProvider(): - return nil, []error{fmt.Errorf("input of %s conflicts with provider %s at %s", - types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos))} + ec.add(fmt.Errorf("input of %s conflicts with provider %s at %s", + types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos))) case pv.IsValue(): - return nil, []error{fmt.Errorf("input of %s conflicts with value at %s", - types.TypeString(g, nil), fset.Position(pv.Value().Pos))} + ec.add(fmt.Errorf("input of %s conflicts with value at %s", + types.TypeString(g, nil), fset.Position(pv.Value().Pos))) default: panic("unknown return value from ProviderSet.For") } + } else { + index.Set(g, i) } - index.Set(g, i) + } + + if len(ec.errors) > 0 { + return nil, ec.errors } // Topological sort of the directed graph defined by the providers // using a depth-first search using a stack. Provider set graphs are - // guaranteed to be acyclic. + // guaranteed to be acyclic. An index value of errAbort indicates that + // the type was visited, but failed due to an error added to ec. + errAbort := errors.New("failed to visit") var calls []call type frame struct { t types.Type from types.Type } stk := []frame{{t: out}} +dfs: for len(stk) > 0 { curr := stk[len(stk)-1] stk = stk[:len(stk)-1] @@ -126,10 +135,14 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide switch pv := set.For(curr.t); { case pv.IsNil(): if curr.from == nil { - return nil, []error{fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil))} + ec.add(fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil))) + index.Set(curr.t, errAbort) + continue } // TODO(light): Give name of provider. - return nil, []error{fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil))} + ec.add(fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil))) + index.Set(curr.t, errAbort) + continue case pv.IsProvider(): p := pv.Provider() if !types.Identical(p.Out, curr.t) { @@ -164,7 +177,12 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide ins := make([]types.Type, len(p.Args)) for i := range p.Args { ins[i] = p.Args[i].Type - args[i] = index.At(p.Args[i].Type).(int) + v := index.At(p.Args[i].Type) + if v == errAbort { + index.Set(curr.t, errAbort) + continue dfs + } + args[i] = v.(int) } index.Set(curr.t, len(given)+len(calls)) kind := funcProviderCall @@ -205,6 +223,9 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide panic("unknown return value from ProviderSet.For") } } + if len(ec.errors) > 0 { + return nil, ec.errors + } return calls, nil } @@ -217,47 +238,62 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider setMap.SetHasher(hasher) // Process imports first, verifying that there are no conflicts between sets. + ec := new(errorCollector) for _, imp := range set.Imports { for _, k := range imp.providerMap.Keys() { if providerMap.At(k) != nil { - return nil, []error{bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet))} + ec.add(bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet))) + continue } providerMap.Set(k, imp.providerMap.At(k)) setMap.Set(k, imp) } } + if len(ec.errors) > 0 { + return nil, ec.errors + } // Process non-binding providers in new set. for _, p := range set.Providers { if providerMap.At(p.Out) != nil { - return nil, []error{bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))} + ec.add(bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))) + continue } providerMap.Set(p.Out, p) setMap.Set(p.Out, set) } for _, v := range set.Values { if providerMap.At(v.Out) != nil { - return nil, []error{bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))} + ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))) + continue } providerMap.Set(v.Out, v) setMap.Set(v.Out, set) } + if len(ec.errors) > 0 { + return nil, ec.errors + } // Process bindings in set. Must happen after the other providers to // ensure the concrete type is being provided. for _, b := range set.Bindings { if providerMap.At(b.Iface) != nil { - return nil, []error{bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))} + ec.add(bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))) + continue } concrete := providerMap.At(b.Provided) if concrete == nil { pos := fset.Position(b.Pos) typ := types.TypeString(b.Provided, nil) - return nil, []error{fmt.Errorf("%v: no binding for %s", pos, typ)} + ec.add(notePosition(pos, fmt.Errorf("no binding for %s", typ))) + continue } providerMap.Set(b.Iface, concrete) setMap.Set(b.Iface, set) } + if len(ec.errors) > 0 { + return nil, ec.errors + } return providerMap, nil } @@ -269,6 +305,7 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { // duplicating work. visited := new(typeutil.Map) // to bool visited.SetHasher(hasher) + ec := new(errorCollector) for _, root := range providerMap.Keys() { // Depth-first search using a stack of trails through the provider map. stk := [][]types.Type{{root}} @@ -288,6 +325,7 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { case *Provider: for _, arg := range x.Args { a := arg.Type + hasCycle := false for i, b := range curr { if types.Identical(a, b) { sb := new(strings.Builder) @@ -297,30 +335,35 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name) } fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil)) - return []error{errors.New(sb.String())} + ec.add(errors.New(sb.String())) + hasCycle = true + break } } - next := append(append([]types.Type(nil), curr...), a) - stk = append(stk, next) + if !hasCycle { + next := append(append([]types.Type(nil), curr...), a) + stk = append(stk, next) + } } default: panic("invalid provider map value") } } } - return nil + return ec.errors } // bindingConflictError creates a new error describing multiple bindings // for the same output type. func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error { - position := fset.Position(pos) typString := types.TypeString(typ, nil) + var err error if prevSet.Name == "" { - prevPosition := fset.Position(prevSet.Pos) - return fmt.Errorf("%v: multiple bindings for %s (previous binding at %v)", - position, typString, prevPosition) + err = fmt.Errorf("multiple bindings for %s (previous binding at %v)", + typString, fset.Position(prevSet.Pos)) + } else { + err = fmt.Errorf("multiple bindings for %s (previous binding in %q.%s)", + typString, prevSet.PkgPath, prevSet.Name) } - return fmt.Errorf("%v: multiple bindings for %s (previous binding in %q.%s)", - position, typString, prevSet.PkgPath, prevSet.Name) + return notePosition(fset.Position(pos), err) } diff --git a/internal/wire/errors.go b/internal/wire/errors.go new file mode 100644 index 0000000..ce74142 --- /dev/null +++ b/internal/wire/errors.go @@ -0,0 +1,84 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "go/token" +) + +// errorCollector manages a list of errors. The zero value is an empty list. +type errorCollector struct { + errors []error +} + +// add appends any non-nil errors to the collector. +func (ec *errorCollector) add(errs ...error) { + for _, e := range errs { + if e != nil { + ec.errors = append(ec.errors, e) + } + } +} + +// mapErrors returns a new slice that wraps any errors using the given function. +func mapErrors(errs []error, f func(error) error) []error { + if len(errs) == 0 { + return nil + } + newErrs := make([]error, len(errs)) + for i := range errs { + newErrs[i] = f(errs[i]) + } + return newErrs +} + +// A wireErr is an error with an optional position. +type wireErr struct { + error error + position token.Position +} + +// notePosition wraps an error with position information if it doesn't already +// have it. +// +// notePosition is usually called multiple times as an error goes up the call +// stack, so calling notePosition on an existing *wireErr will not modify the +// position, as the assumption is that deeper calls have more precise position +// information about the source of the error. +func notePosition(p token.Position, e error) error { + switch e.(type) { + case nil: + return nil + case *wireErr: + return e + default: + return &wireErr{error: e, position: p} + } +} + +// notePositionAll wraps a list of errors with the given position. +func notePositionAll(p token.Position, errs []error) []error { + return mapErrors(errs, func(e error) error { + return notePosition(p, e) + }) +} + +// Error returns the error message prefixed by the position if valid. +func (w *wireErr) Error() string { + if !w.position.IsValid() { + return w.error.Error() + } + return w.position.String() + ": " + w.error.Error() +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 32559a8..fba104f 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -144,21 +144,29 @@ type Value struct { } // Load finds all the provider sets in the given packages, as well as -// the provider sets' transitive dependencies. It may return both an error +// the provider sets' transitive dependencies. It may return both errors // and Info. func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { - // TODO(light): Stop errors from printing to stderr. + ec := new(errorCollector) conf := &loader.Config{ - Build: bctx, - Cwd: wd, + Build: bctx, + Cwd: wd, + TypeChecker: types.Config{ + Error: func(err error) { + ec.add(err) + }, + }, TypeCheckFuncBodies: func(string) bool { return false }, } for _, p := range pkgs { conf.Import(p) } prog, err := conf.Load() + if len(ec.errors) > 0 { + return nil, ec.errors + } if err != nil { - return nil, []error{fmt.Errorf("load: %v", err)} + return nil, []error{err} } info := &Info{ Fset: prog.Fset, @@ -168,21 +176,23 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { for _, pkgInfo := range prog.InitialPackages() { scope := pkgInfo.Pkg.Scope() for _, name := range scope.Names() { - item, err := oc.get(scope.Lookup(name)) - if err != nil { + obj := scope.Lookup(name) + if !isProviderSetType(obj.Type()) { continue } - pset, ok := item.(*ProviderSet) - if !ok { + item, errs := oc.get(obj) + if len(errs) > 0 { + ec.add(notePositionAll(prog.Fset.Position(obj.Pos()), errs)...) continue } + pset := item.(*ProviderSet) // pset.Name may not equal name, since it could be an alias to // another provider set. id := ProviderSetID{ImportPath: pset.PkgPath, VarName: name} info.Sets[id] = pset } } - return info, nil + return info, ec.errors } // Info holds the result of Load. @@ -293,59 +303,47 @@ func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr) (inte exprPos := oc.prog.Fset.Position(expr.Pos()) expr = astutil.Unparen(expr) if obj := qualifiedIdentObject(&pkg.Info, expr); obj != nil { - item, err := oc.get(obj) - if err != nil { - return nil, []error{fmt.Errorf("%v: %v", exprPos, err)} - } - return item, nil + item, errs := oc.get(obj) + return item, mapErrors(errs, func(err error) error { + return notePosition(exprPos, err) + }) } if call, ok := expr.(*ast.CallExpr); ok { fnObj := qualifiedIdentObject(&pkg.Info, call.Fun) if fnObj == nil || !isWireImport(fnObj.Pkg().Path()) { - return nil, []error{fmt.Errorf("%v: unknown pattern", exprPos)} + return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} } switch fnObj.Name() { case "NewSet": pset, errs := oc.processNewSet(pkg, call) - if len(errs) > 0 { - errs = append([]error(nil), errs...) - for i := range errs { - errs[i] = fmt.Errorf("%v: %v", exprPos, errs[i]) - } - return nil, errs - } - return pset, nil + return pset, notePositionAll(exprPos, errs) case "Bind": b, err := processBind(oc.prog.Fset, &pkg.Info, call) if err != nil { - return nil, []error{fmt.Errorf("%v: %v", exprPos, err)} + return nil, []error{notePosition(exprPos, err)} } return b, nil case "Value": v, err := processValue(oc.prog.Fset, &pkg.Info, call) if err != nil { - return nil, []error{fmt.Errorf("%v: %v", exprPos, err)} + return nil, []error{notePosition(exprPos, err)} } return v, nil default: - return nil, []error{fmt.Errorf("%v: unknown pattern", exprPos)} + return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} } } if tn := structArgType(&pkg.Info, expr); tn != nil { p, errs := processStructProvider(oc.prog.Fset, tn) if len(errs) > 0 { - errs = append([]error(nil), errs...) - for i := range errs { - errs[i] = fmt.Errorf("%v: %v", exprPos, errs[i]) - } - return nil, errs + return nil, notePositionAll(exprPos, errs) } ptrp := new(Provider) *ptrp = *p ptrp.Out = types.NewPointer(p.Out) return structProviderPair{p, ptrp}, nil } - return nil, []error{fmt.Errorf("%v: unknown pattern", exprPos)} + return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} } type structProviderPair struct { @@ -360,10 +358,12 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr Pos: call.Pos(), PkgPath: pkg.Pkg.Path(), } + ec := new(errorCollector) for _, arg := range call.Args { - item, err := oc.processExpr(pkg, arg) - if err != nil { - return nil, err + item, errs := oc.processExpr(pkg, arg) + if len(errs) > 0 { + ec.add(errs...) + continue } switch item := item.(type) { case *Provider: @@ -380,6 +380,9 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr panic("unknown item type") } } + if len(ec.errors) > 0 { + return nil, ec.errors + } var errs []error pset.providerMap, errs = buildProviderMap(oc.prog.Fset, oc.hasher, pset) if len(errs) > 0 { @@ -434,7 +437,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro fpos := fn.Pos() providerSig, err := funcOutput(sig) if err != nil { - return nil, []error{fmt.Errorf("%v: wrong signature for provider %s: %v", fset.Position(fpos), fn.Name(), err)} + return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("wrong signature for provider %s: %v", fn.Name(), err))} } params := sig.Params() provider := &Provider{ @@ -452,7 +455,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro } for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, []error{fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil))} + return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("provider has multiple parameters of type %s", types.TypeString(provider.Args[j].Type, nil)))} } } } @@ -527,7 +530,7 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov provider.Fields[i] = f.Name() for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, []error{fmt.Errorf("%v: provider struct has multiple fields of type %s", fset.Position(pos), types.TypeString(provider.Args[j].Type, nil))} + return nil, []error{notePosition(fset.Position(pos), fmt.Errorf("provider struct has multiple fields of type %s", types.TypeString(provider.Args[j].Type, nil)))} } } } @@ -539,24 +542,24 @@ func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*If // Assumes that call.Fun is wire.Bind. if len(call.Args) != 2 { - return nil, fmt.Errorf("%v: call to Bind takes exactly two arguments", fset.Position(call.Pos())) + return nil, notePosition(fset.Position(call.Pos()), errors.New("call to Bind takes exactly two arguments")) } // TODO(light): Verify that arguments are simple expressions. ifaceArgType := info.TypeOf(call.Args[0]) ifacePtr, ok := ifaceArgType.(*types.Pointer) if !ok { - return nil, fmt.Errorf("%v: first argument to bind must be a pointer to an interface type; found %s", fset.Position(call.Pos()), types.TypeString(ifaceArgType, nil)) + return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("first argument to bind must be a pointer to an interface type; found %s", types.TypeString(ifaceArgType, nil))) } methodSet, ok := ifacePtr.Elem().Underlying().(*types.Interface) if !ok { - return nil, fmt.Errorf("%v: first argument to bind must be a pointer to an interface type; found %s", fset.Position(call.Pos()), types.TypeString(ifaceArgType, nil)) + return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("first argument to bind must be a pointer to an interface type; found %s", types.TypeString(ifaceArgType, nil))) } provided := info.TypeOf(call.Args[1]) if types.Identical(ifacePtr.Elem(), provided) { - return nil, fmt.Errorf("%v: cannot bind interface to itself", fset.Position(call.Pos())) + return nil, notePosition(fset.Position(call.Pos()), errors.New("cannot bind interface to itself")) } if !types.Implements(provided, methodSet) { - return nil, fmt.Errorf("%v: %s does not implement %s", fset.Position(call.Pos()), types.TypeString(provided, nil), types.TypeString(ifaceArgType, nil)) + return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("%s does not implement %s", types.TypeString(provided, nil), types.TypeString(ifaceArgType, nil))) } return &IfaceBinding{ Pos: call.Pos(), @@ -570,7 +573,7 @@ func processValue(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*V // Assumes that call.Fun is wire.Value. if len(call.Args) != 1 { - return nil, fmt.Errorf("%v: call to Value takes exactly one argument", fset.Position(call.Pos())) + return nil, notePosition(fset.Position(call.Pos()), errors.New("call to Value takes exactly one argument")) } ok := true ast.Inspect(call.Args[0], func(node ast.Node) bool { @@ -597,7 +600,7 @@ func processValue(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*V return true }) if !ok { - return nil, fmt.Errorf("%v: argument to Value is too complex", fset.Position(call.Pos())) + return nil, notePosition(fset.Position(call.Pos()), errors.New("argument to Value is too complex")) } return &Value{ Pos: call.Args[0].Pos(), @@ -665,6 +668,15 @@ func isWireImport(path string) bool { return path == "github.com/google/go-cloud/wire" } +func isProviderSetType(t types.Type) bool { + n, ok := t.(*types.Named) + if !ok { + return false + } + obj := n.Obj() + return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet" +} + // ProviderOrValue is a pointer to a Provider or a Value. The zero value is // a nil pointer. type ProviderOrValue struct { diff --git a/internal/wire/testdata/MultipleMissingInputs/foo/foo.go b/internal/wire/testdata/MultipleMissingInputs/foo/foo.go new file mode 100644 index 0000000..4c1a7d5 --- /dev/null +++ b/internal/wire/testdata/MultipleMissingInputs/foo/foo.go @@ -0,0 +1,29 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "fmt" + +func main() { + fmt.Println(injectBaz()) +} + +type Foo int +type Bar int +type Baz int + +func provideBaz(foo Foo, bar Bar) Baz { + return 0 +} diff --git a/internal/wire/testdata/MultipleMissingInputs/foo/wire.go b/internal/wire/testdata/MultipleMissingInputs/foo/wire.go new file mode 100644 index 0000000..d304057 --- /dev/null +++ b/internal/wire/testdata/MultipleMissingInputs/foo/wire.go @@ -0,0 +1,25 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package main + +import ( + "github.com/google/go-cloud/wire" +) + +func injectBaz() Baz { + panic(wire.Build(provideBaz)) +} diff --git a/internal/wire/testdata/MultipleMissingInputs/out.txt b/internal/wire/testdata/MultipleMissingInputs/out.txt new file mode 100644 index 0000000..e66e81f --- /dev/null +++ b/internal/wire/testdata/MultipleMissingInputs/out.txt @@ -0,0 +1,3 @@ +ERROR +no provider found for foo.Foo +no provider found for foo.Bar diff --git a/internal/wire/testdata/MultipleMissingInputs/pkg b/internal/wire/testdata/MultipleMissingInputs/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/wire/testdata/MultipleMissingInputs/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 72e6943..0aefb63 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -43,10 +43,15 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { if err != nil { return nil, []error{fmt.Errorf("load: %v", err)} } - // TODO(light): Stop errors from printing to stderr. + ec := new(errorCollector) conf := &loader.Config{ Build: bctx, Cwd: wd, + TypeChecker: types.Config{ + Error: func(err error) { + ec.add(err) + }, + }, TypeCheckFuncBodies: func(path string) bool { return path == mainPkg.ImportPath }, @@ -71,8 +76,11 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { conf.Import(pkg) prog, err := conf.Load() + if len(ec.errors) > 0 { + return nil, ec.errors + } if err != nil { - return nil, []error{fmt.Errorf("load: %v", err)} + return nil, []error{err} } if len(prog.InitialPackages()) != 1 { // This is more of a violated precondition than anything else. @@ -99,6 +107,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*ast.File, _ []error) { oc := newObjectCache(g.prog) injectorFiles = make([]*ast.File, 0, len(pkgInfo.Files)) + ec := new(errorCollector) for _, f := range pkgInfo.Files { for _, decl := range f.Decls { fn, ok := decl.(*ast.FuncDecl) @@ -118,24 +127,19 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as } set, errs := oc.processNewSet(pkgInfo, buildCall) if len(errs) > 0 { - position := g.prog.Fset.Position(fn.Pos()) - errs = append([]error(nil), errs...) - for i := range errs { - errs[i] = fmt.Errorf("%v: %v", position, errs[i]) - } - return nil, errs + ec.add(notePositionAll(g.prog.Fset.Position(fn.Pos()), errs)...) + continue } sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) - if errs := g.inject(fn.Name.Name, sig, set); len(errs) > 0 { - position := g.prog.Fset.Position(fn.Pos()) - errs = append([]error(nil), errs...) - for i := range errs { - errs[i] = fmt.Errorf("%v: %v", position, errs[i]) - } - return nil, errs + if errs := g.inject(fn.Pos(), fn.Name.Name, sig, set); len(errs) > 0 { + ec.add(errs...) + continue } } } + if len(ec.errors) > 0 { + return nil, ec.errors + } return injectorFiles, nil } @@ -218,10 +222,11 @@ func (g *gen) frame() []byte { } // inject emits the code for an injector. -func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) []error { +func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *ProviderSet) []error { injectSig, err := funcOutput(sig) if err != nil { - return []error{fmt.Errorf("inject %s: %v", name, err)} + return []error{notePosition(g.prog.Fset.Position(pos), + fmt.Errorf("inject %s: %v", name, err))} } params := sig.Params() given := make([]types.Type, params.Len()) @@ -230,7 +235,12 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) []erro } calls, errs := solve(g.prog.Fset, injectSig.out, given, set) if len(errs) > 0 { - return errs + return mapErrors(errs, func(e error) error { + if w, ok := e.(*wireErr); ok { + return notePosition(w.position, fmt.Errorf("inject %s: %v", name, w.error)) + } + return notePosition(g.prog.Fset.Position(pos), fmt.Errorf("inject %s: %v", name, e)) + }) } type pendingVar struct { name string @@ -238,19 +248,28 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) []erro typeInfo *types.Info } var pendingVars []pendingVar + ec := new(errorCollector) for i := range calls { c := &calls[i] if c.hasCleanup && !injectSig.cleanup { - return []error{fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(c.out, nil))} + ts := types.TypeString(c.out, nil) + ec.add(notePosition( + g.prog.Fset.Position(pos), + fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, ts))) } if c.hasErr && !injectSig.err { - return []error{fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(c.out, nil))} + ts := types.TypeString(c.out, nil) + ec.add(notePosition( + g.prog.Fset.Position(pos), + fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, ts))) } if c.kind == valueExpr { if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.currPackage); err != nil { // TODO(light): Display line number of value expression. ts := types.TypeString(c.out, nil) - return []error{fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err)} + ec.add(notePosition( + g.prog.Fset.Position(pos), + fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err))) } if g.values[c.valueExpr] == "" { t := c.valueTypeInfo.TypeOf(c.valueExpr) @@ -264,6 +283,9 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) []erro } } } + if len(ec.errors) > 0 { + return ec.errors + } // Perform one pass to collect all imports, followed by the real pass. injectPass(name, params, injectSig, calls, &injectorGen{