From a2540bae2a9b6976b97cd946f8e1999ca2c5cf1e Mon Sep 17 00:00:00 2001 From: Ross Light Date: Fri, 13 Jul 2018 08:51:43 -0700 Subject: [PATCH] wire: update internal functions to return []error (google/go-cloud#197) This represents no functional change, it purely changes the signature used for functions that can possibly return multiple errors. A follow-up commit will change the control flow to proceed in the face of errors. --- internal/wire/analyze.go | 32 ++++++------- internal/wire/parse.go | 97 ++++++++++++++++++++++------------------ internal/wire/wire.go | 44 +++++++++++------- 3 files changed, 96 insertions(+), 77 deletions(-) diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 4e40417..fd9a4ec 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -79,11 +79,11 @@ type call struct { // solve finds the sequence of calls required to produce an output type // with an optional set of provided inputs. -func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, error) { +func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) { for i, g := range given { for _, h := range given[:i] { if types.Identical(g, h) { - return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil)) + return nil, []error{fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))} } } } @@ -95,11 +95,11 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide if pv := set.For(g); !pv.IsNil() { switch { case pv.IsProvider(): - return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", - types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos)) + return nil, []error{fmt.Errorf("input of %s conflicts with provider %s at %s", + types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos))} case pv.IsValue(): - return nil, fmt.Errorf("input of %s conflicts with value at %s", - types.TypeString(g, nil), fset.Position(pv.Value().Pos)) + return nil, []error{fmt.Errorf("input of %s conflicts with value at %s", + types.TypeString(g, nil), fset.Position(pv.Value().Pos))} default: panic("unknown return value from ProviderSet.For") } @@ -126,10 +126,10 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide switch pv := set.For(curr.t); { case pv.IsNil(): if curr.from == nil { - return nil, fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil)) + return nil, []error{fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil))} } // TODO(light): Give name of provider. - return nil, fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil)) + return nil, []error{fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil))} case pv.IsProvider(): p := pv.Provider() if !types.Identical(p.Out, curr.t) { @@ -210,7 +210,7 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide // buildProviderMap creates the providerMap field for a given provider set. // The given provider set's providerMap field is ignored. -func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, error) { +func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, []error) { providerMap := new(typeutil.Map) providerMap.SetHasher(hasher) setMap := new(typeutil.Map) // to *ProviderSet, for error messages @@ -220,7 +220,7 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider for _, imp := range set.Imports { for _, k := range imp.providerMap.Keys() { if providerMap.At(k) != nil { - return nil, bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet)) + return nil, []error{bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet))} } providerMap.Set(k, imp.providerMap.At(k)) setMap.Set(k, imp) @@ -230,14 +230,14 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider // Process non-binding providers in new set. for _, p := range set.Providers { if providerMap.At(p.Out) != nil { - return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)) + return nil, []error{bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))} } providerMap.Set(p.Out, p) setMap.Set(p.Out, set) } for _, v := range set.Values { if providerMap.At(v.Out) != nil { - return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)) + return nil, []error{bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))} } providerMap.Set(v.Out, v) setMap.Set(v.Out, set) @@ -247,13 +247,13 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider // ensure the concrete type is being provided. for _, b := range set.Bindings { if providerMap.At(b.Iface) != nil { - return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet)) + return nil, []error{bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))} } concrete := providerMap.At(b.Provided) if concrete == nil { pos := fset.Position(b.Pos) typ := types.TypeString(b.Provided, nil) - return nil, fmt.Errorf("%v: no binding for %s", pos, typ) + return nil, []error{fmt.Errorf("%v: no binding for %s", pos, typ)} } providerMap.Set(b.Iface, concrete) setMap.Set(b.Iface, set) @@ -261,7 +261,7 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider return providerMap, nil } -func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) error { +func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { // We must visit every provider type inside provider map, but we don't // have a well-defined starting point and there may be several // distinct graphs. Thus, we start a depth-first search at every @@ -297,7 +297,7 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) error { fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name) } fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil)) - return errors.New(sb.String()) + return []error{errors.New(sb.String())} } } next := append(append([]types.Type(nil), curr...), a) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 91144d7..32559a8 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -207,7 +207,7 @@ func (id ProviderSetID) String() string { // objectCache is a lazily evaluated mapping of objects to Wire structures. type objectCache struct { prog *loader.Program - objects map[objRef]interface{} // *Provider, *ProviderSet, *IfaceBinding, or *Value + objects map[objRef]objCacheEntry hasher typeutil.Hasher } @@ -216,10 +216,15 @@ type objRef struct { name string } +type objCacheEntry struct { + val interface{} // *Provider, *ProviderSet, *IfaceBinding, or *Value + errs []error +} + func newObjectCache(prog *loader.Program) *objectCache { return &objectCache{ prog: prog, - objects: make(map[objRef]interface{}), + objects: make(map[objRef]objCacheEntry), hasher: typeutil.MakeHasher(), } } @@ -227,22 +232,25 @@ func newObjectCache(prog *loader.Program) *objectCache { // get converts a Go object into a Wire structure. It may return a // *Provider, a structProviderPair, an *IfaceBinding, a *ProviderSet, // or a *Value. -func (oc *objectCache) get(obj types.Object) (interface{}, error) { +func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { ref := objRef{ importPath: obj.Pkg().Path(), name: obj.Name(), } - if val, cached := oc.objects[ref]; cached { - if val == nil { - return nil, fmt.Errorf("%v is not a provider or a provider set", obj) - } - return val, nil + if ent, cached := oc.objects[ref]; cached { + return ent.val, append([]error(nil), ent.errs...) } + defer func() { + oc.objects[ref] = objCacheEntry{ + val: val, + errs: append([]error(nil), errs...), + } + }() switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) if len(spec.Values) == 0 { - return nil, fmt.Errorf("%v is not a provider or a provider set", obj) + return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } var i int for i = range spec.Names { @@ -252,16 +260,9 @@ func (oc *objectCache) get(obj types.Object) (interface{}, error) { } return oc.processExpr(oc.prog.Package(obj.Pkg().Path()), spec.Values[i]) case *types.Func: - p, err := processFuncProvider(oc.prog.Fset, obj) - if err != nil { - oc.objects[ref] = nil - return nil, err - } - oc.objects[ref] = p - return p, nil + return processFuncProvider(oc.prog.Fset, obj) default: - oc.objects[ref] = nil - return nil, fmt.Errorf("%v is not a provider or a provider set", obj) + return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } } @@ -288,55 +289,63 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // processExpr converts an expression into a Wire structure. It may // return a *Provider, a structProviderPair, an *IfaceBinding, a // *ProviderSet, or a *Value. -func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr) (interface{}, error) { +func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr) (interface{}, []error) { exprPos := oc.prog.Fset.Position(expr.Pos()) expr = astutil.Unparen(expr) if obj := qualifiedIdentObject(&pkg.Info, expr); obj != nil { item, err := oc.get(obj) if err != nil { - return nil, fmt.Errorf("%v: %v", exprPos, err) + return nil, []error{fmt.Errorf("%v: %v", exprPos, err)} } return item, nil } if call, ok := expr.(*ast.CallExpr); ok { fnObj := qualifiedIdentObject(&pkg.Info, call.Fun) if fnObj == nil || !isWireImport(fnObj.Pkg().Path()) { - return nil, fmt.Errorf("%v: unknown pattern", exprPos) + return nil, []error{fmt.Errorf("%v: unknown pattern", exprPos)} } switch fnObj.Name() { case "NewSet": - pset, err := oc.processNewSet(pkg, call) - if err != nil { - return nil, fmt.Errorf("%v: %v", exprPos, err) + pset, errs := oc.processNewSet(pkg, call) + if len(errs) > 0 { + errs = append([]error(nil), errs...) + for i := range errs { + errs[i] = fmt.Errorf("%v: %v", exprPos, errs[i]) + } + return nil, errs } return pset, nil case "Bind": b, err := processBind(oc.prog.Fset, &pkg.Info, call) if err != nil { - return nil, fmt.Errorf("%v: %v", exprPos, err) + return nil, []error{fmt.Errorf("%v: %v", exprPos, err)} } return b, nil case "Value": v, err := processValue(oc.prog.Fset, &pkg.Info, call) if err != nil { - return nil, fmt.Errorf("%v: %v", exprPos, err) + return nil, []error{fmt.Errorf("%v: %v", exprPos, err)} } return v, nil default: - return nil, fmt.Errorf("%v: unknown pattern", exprPos) + return nil, []error{fmt.Errorf("%v: unknown pattern", exprPos)} } } if tn := structArgType(&pkg.Info, expr); tn != nil { - p, err := processStructProvider(oc.prog.Fset, tn) - if err != nil { - return nil, fmt.Errorf("%v: %v", exprPos, err) + p, errs := processStructProvider(oc.prog.Fset, tn) + if len(errs) > 0 { + errs = append([]error(nil), errs...) + for i := range errs { + errs[i] = fmt.Errorf("%v: %v", exprPos, errs[i]) + } + return nil, errs } ptrp := new(Provider) *ptrp = *p ptrp.Out = types.NewPointer(p.Out) return structProviderPair{p, ptrp}, nil } - return nil, fmt.Errorf("%v: unknown pattern", exprPos) + return nil, []error{fmt.Errorf("%v: unknown pattern", exprPos)} } type structProviderPair struct { @@ -344,7 +353,7 @@ type structProviderPair struct { ptrProvider *Provider } -func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr) (*ProviderSet, error) { +func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr) (*ProviderSet, []error) { // Assumes that call.Fun is wire.NewSet or wire.Build. pset := &ProviderSet{ @@ -371,13 +380,13 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr panic("unknown item type") } } - var err error - pset.providerMap, err = buildProviderMap(oc.prog.Fset, oc.hasher, pset) - if err != nil { - return nil, err + var errs []error + pset.providerMap, errs = buildProviderMap(oc.prog.Fset, oc.hasher, pset) + if len(errs) > 0 { + return nil, errs } - if err := verifyAcyclic(pset.providerMap, oc.hasher); err != nil { - return nil, err + if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { + return nil, errs } return pset, nil } @@ -420,12 +429,12 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object { } // processFuncProvider creates a provider for a function declaration. -func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) { +func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []error) { sig := fn.Type().(*types.Signature) fpos := fn.Pos() providerSig, err := funcOutput(sig) if err != nil { - return nil, fmt.Errorf("%v: wrong signature for provider %s: %v", fset.Position(fpos), fn.Name(), err) + return nil, []error{fmt.Errorf("%v: wrong signature for provider %s: %v", fset.Position(fpos), fn.Name(), err)} } params := sig.Params() provider := &Provider{ @@ -443,7 +452,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) } for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil)) + return nil, []error{fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil))} } } } @@ -493,11 +502,11 @@ func funcOutput(sig *types.Signature) (outputSignature, error) { // processStructProvider creates a provider for a named struct type. // It only produces the non-pointer variant. -func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) { +func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { out := typeName.Type() st, ok := out.Underlying().(*types.Struct) if !ok { - return nil, fmt.Errorf("%v does not name a struct", typeName) + return nil, []error{fmt.Errorf("%v does not name a struct", typeName)} } pos := typeName.Pos() @@ -518,7 +527,7 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov provider.Fields[i] = f.Name() for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fset.Position(pos), types.TypeString(provider.Args[j].Type, nil)) + return nil, []error{fmt.Errorf("%v: provider struct has multiple fields of type %s", fset.Position(pos), types.TypeString(provider.Args[j].Type, nil))} } } } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 5acb5cb..72e6943 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -80,9 +80,9 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { } pkgInfo := prog.InitialPackages()[0] g := newGen(prog, pkgInfo.Pkg.Path()) - injectorFiles, err := generateInjectors(g, pkgInfo) - if err != nil { - return nil, []error{err} + injectorFiles, errs := generateInjectors(g, pkgInfo) + if len(errs) > 0 { + return nil, errs } copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info) goSrc := g.frame() @@ -96,7 +96,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { } // generateInjectors generates the injectors for a given package. -func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*ast.File, _ error) { +func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*ast.File, _ []error) { oc := newObjectCache(g.prog) injectorFiles = make([]*ast.File, 0, len(pkgInfo.Files)) for _, f := range pkgInfo.Files { @@ -116,13 +116,23 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as g.p("// Injectors from %s:\n\n", name) injectorFiles = append(injectorFiles, f) } - set, err := oc.processNewSet(pkgInfo, buildCall) - if err != nil { - return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err) + set, errs := oc.processNewSet(pkgInfo, buildCall) + if len(errs) > 0 { + position := g.prog.Fset.Position(fn.Pos()) + errs = append([]error(nil), errs...) + for i := range errs { + errs[i] = fmt.Errorf("%v: %v", position, errs[i]) + } + return nil, errs } sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) - if err := g.inject(fn.Name.Name, sig, set); err != nil { - return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err) + if errs := g.inject(fn.Name.Name, sig, set); len(errs) > 0 { + position := g.prog.Fset.Position(fn.Pos()) + errs = append([]error(nil), errs...) + for i := range errs { + errs[i] = fmt.Errorf("%v: %v", position, errs[i]) + } + return nil, errs } } } @@ -208,19 +218,19 @@ func (g *gen) frame() []byte { } // inject emits the code for an injector. -func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error { +func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) []error { injectSig, err := funcOutput(sig) if err != nil { - return fmt.Errorf("inject %s: %v", name, err) + return []error{fmt.Errorf("inject %s: %v", name, err)} } params := sig.Params() given := make([]types.Type, params.Len()) for i := 0; i < params.Len(); i++ { given[i] = params.At(i).Type() } - calls, err := solve(g.prog.Fset, injectSig.out, given, set) - if err != nil { - return err + calls, errs := solve(g.prog.Fset, injectSig.out, given, set) + if len(errs) > 0 { + return errs } type pendingVar struct { name string @@ -231,16 +241,16 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error for i := range calls { c := &calls[i] if c.hasCleanup && !injectSig.cleanup { - return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(c.out, nil)) + return []error{fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(c.out, nil))} } if c.hasErr && !injectSig.err { - return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(c.out, nil)) + return []error{fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(c.out, nil))} } if c.kind == valueExpr { if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.currPackage); err != nil { // TODO(light): Display line number of value expression. ts := types.TypeString(c.out, nil) - return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err) + return []error{fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err)} } if g.values[c.valueExpr] == "" { t := c.valueTypeInfo.TypeOf(c.valueExpr)