From e93f33129e0f0c61e9e257d52dcddd9d36f4f892 Mon Sep 17 00:00:00 2001 From: Robert van Gent Date: Thu, 16 Aug 2018 15:36:53 -0700 Subject: [PATCH] wire: update Provider.Out to be a slice of provided types, and keep track of the provided concrete type in ProviderSet.providerMap (google/go-cloud#332) Update Provider.Out to be a slice of provided types, and keep track of the provided concrete type in ProviderSet.providerMap, to more clearly model-named struct providers (which provide both the struct type and a pointer to the struct type). Fixes google/go-cloud#325. --- internal/wire/analyze.go | 77 ++++++++++--------- internal/wire/parse.go | 72 ++++++++--------- .../wire/testdata/StructPointer/foo/foo.go | 5 +- .../wire/testdata/StructPointer/foo/wire.go | 5 ++ .../StructPointer/want/program_out.txt | 2 +- .../testdata/StructPointer/want/wire_gen.go | 5 ++ 6 files changed, 88 insertions(+), 78 deletions(-) diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 94bef96..7d93c06 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -147,11 +147,11 @@ dfs: p := pv.Provider() src := set.srcMap.At(curr.t).(*providerSetSrc) used = append(used, src) - if !types.Identical(p.Out, curr.t) { + if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) { // Interface binding. Don't create a call ourselves. - i := index.At(p.Out) + i := index.At(concrete) if i == nil { - stk = append(stk, curr, frame{t: p.Out, from: curr.t}) + stk = append(stk, curr, frame{t: concrete, from: curr.t}) continue } index.Set(curr.t, i) @@ -323,20 +323,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider // Process non-binding providers in new set. for _, p := range set.Providers { - if providerMap.At(p.Out) != nil { - ec.add(bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))) - continue + src := &providerSetSrc{Provider: p} + for _, typ := range p.Out { + if providerMap.At(typ) != nil { + ec.add(bindingConflictError(fset, p.Pos, typ, setMap.At(typ).(*ProviderSet))) + continue + } + providerMap.Set(typ, &ProvidedType{t: typ, p: p}) + srcMap.Set(typ, src) + setMap.Set(typ, set) } - providerMap.Set(p.Out, p) - srcMap.Set(p.Out, &providerSetSrc{Provider: p}) - setMap.Set(p.Out, set) } for _, v := range set.Values { if providerMap.At(v.Out) != nil { ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))) continue } - providerMap.Set(v.Out, v) + providerMap.Set(v.Out, &ProvidedType{t: v.Out, v: v}) srcMap.Set(v.Out, &providerSetSrc{Value: v}) setMap.Set(v.Out, set) } @@ -388,36 +391,40 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { continue } visited.Set(head, true) - switch x := providerMap.At(head).(type) { - case nil: + x := providerMap.At(head) + if x == nil { // Leaf: input. - case *Value: + continue + } + pt := x.(*ProvidedType) + if pt.IsValue() { // Leaf: values do not have dependencies. - case *Provider: - for _, arg := range x.Args { - a := arg.Type - hasCycle := false - for i, b := range curr { - if types.Identical(a, b) { - sb := new(strings.Builder) - fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil)) - for j := i; j < len(curr); j++ { - p := providerMap.At(curr[j]).(*Provider) - fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name) - } - fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil)) - ec.add(errors.New(sb.String())) - hasCycle = true - break + continue + } + if !pt.IsProvider() { + panic("invalid provider map value") + } + for _, arg := range pt.Provider().Args { + a := arg.Type + hasCycle := false + for i, b := range curr { + if types.Identical(a, b) { + sb := new(strings.Builder) + fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil)) + for j := i; j < len(curr); j++ { + p := providerMap.At(curr[j]).(*ProvidedType).Provider() + fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name) } - } - if !hasCycle { - next := append(append([]types.Type(nil), curr...), a) - stk = append(stk, next) + fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil)) + ec.add(errors.New(sb.String())) + hasCycle = true + break } } - default: - panic("invalid provider map value") + if !hasCycle { + next := append(append([]types.Type(nil), curr...), a) + stk = append(stk, next) + } } } } diff --git a/internal/wire/parse.go b/internal/wire/parse.go index e73f713..a1d939d 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -55,7 +55,7 @@ type ProviderSet struct { Values []*Value Imports []*ProviderSet - // providerMap maps from provided type to a *Provider or *Value. + // providerMap maps from provided type to a *ProvidedType. // It includes all of the imported types. providerMap *typeutil.Map @@ -70,19 +70,13 @@ func (set *ProviderSet) Outputs() []types.Type { return set.providerMap.Keys() } -// For returns the provider or value for the given type, or the zero -// ProviderOrValue. -func (set *ProviderSet) For(t types.Type) ProviderOrValue { - switch x := set.providerMap.At(t).(type) { - case nil: - return ProviderOrValue{} - case *Provider: - return ProviderOrValue{p: x} - case *Value: - return ProviderOrValue{v: x} - default: - panic("invalid value in typeMap") +// For returns a ProvidedType for the given type, or the zero ProvidedType. +func (set *ProviderSet) For(t types.Type) ProvidedType { + pt := set.providerMap.At(t) + if pt == nil { + return ProvidedType{} } + return *pt.(*ProvidedType) } // An IfaceBinding declares that a type should be used to satisfy inputs @@ -122,8 +116,9 @@ type Provider struct { // elements in Args. Fields []string - // Out is the type this provider produces. - Out types.Type + // Out is the set of types this provider produces. It will always + // contain at least one type. + Out []types.Type // HasCleanup reports whether the provider function returns a cleanup // function. (Always false for structs.) @@ -365,8 +360,7 @@ func newObjectCache(prog *loader.Program) *objectCache { } // get converts a Go object into a Wire structure. It may return a -// *Provider, a structProviderPair, an *IfaceBinding, a *ProviderSet, -// or a *Value. +// *Provider, an *IfaceBinding, a *ProviderSet, or a *Value. func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { ref := objRef{ importPath: obj.Pkg().Path(), @@ -422,8 +416,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { } // processExpr converts an expression into a Wire structure. It may -// return a *Provider, a structProviderPair, an *IfaceBinding, a -// *ProviderSet, or a *Value. +// return a *Provider, an *IfaceBinding, a *ProviderSet, or a *Value. func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varName string) (interface{}, []error) { exprPos := oc.prog.Fset.Position(expr.Pos()) expr = astutil.Unparen(expr) @@ -469,19 +462,11 @@ func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varNa if len(errs) > 0 { return nil, notePositionAll(exprPos, errs) } - ptrp := new(Provider) - *ptrp = *p - ptrp.Out = types.NewPointer(p.Out) - return structProviderPair{p, ptrp}, nil + return p, nil } return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} } -type structProviderPair struct { - provider *Provider - ptrProvider *Provider -} - func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr, varName string) (*ProviderSet, []error) { // Assumes that call.Fun is wire.NewSet or wire.Build. @@ -504,8 +489,6 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr pset.Imports = append(pset.Imports, item) case *IfaceBinding: pset.Bindings = append(pset.Bindings, item) - case structProviderPair: - pset.Providers = append(pset.Providers, item.provider, item.ptrProvider) case *Value: pset.Values = append(pset.Values, item) default: @@ -577,7 +560,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro Name: fn.Name(), Pos: fn.Pos(), Args: make([]ProviderInput, params.Len()), - Out: providerSig.out, + Out: []types.Type{providerSig.out}, HasCleanup: providerSig.cleanup, HasErr: providerSig.err, } @@ -649,7 +632,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) { } // processStructProvider creates a provider for a named struct type. -// It only produces the non-pointer variant. +// It produces pointer and non-pointer variants via two values in Out. func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { out := typeName.Type() st, ok := out.Underlying().(*types.Struct) @@ -665,7 +648,7 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov Args: make([]ProviderInput, st.NumFields()), Fields: make([]string, st.NumFields()), IsStruct: true, - Out: out, + Out: []types.Type{out, types.NewPointer(out)}, } for i := 0; i < st.NumFields(); i++ { f := st.Field(i) @@ -854,31 +837,38 @@ func isProviderSetType(t types.Type) bool { return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet" } -// ProviderOrValue is a pointer to a Provider or a Value. The zero value is -// a nil pointer. -type ProviderOrValue struct { +// ProvidedType is a pointer to a Provider or a Value. The zero value is +// a nil pointer. It also holds the concrete type that the Provider or Value +// provided. +type ProvidedType struct { + t types.Type p *Provider v *Value } // IsNil reports whether pv is the zero value. -func (pv ProviderOrValue) IsNil() bool { +func (pv ProvidedType) IsNil() bool { return pv.p == nil && pv.v == nil } +// ConcreteType returns the concrete type that was provided. +func (pv ProvidedType) ConcreteType() types.Type { + return pv.t +} + // IsProvider reports whether pv points to a Provider. -func (pv ProviderOrValue) IsProvider() bool { +func (pv ProvidedType) IsProvider() bool { return pv.p != nil } // IsValue reports whether pv points to a Value. -func (pv ProviderOrValue) IsValue() bool { +func (pv ProvidedType) IsValue() bool { return pv.v != nil } // Provider returns pv as a Provider pointer. It panics if pv points to a // Value. -func (pv ProviderOrValue) Provider() *Provider { +func (pv ProvidedType) Provider() *Provider { if pv.v != nil { panic("Value pointer converted to a Provider") } @@ -887,7 +877,7 @@ func (pv ProviderOrValue) Provider() *Provider { // Value returns pv as a Value pointer. It panics if pv points to a // Provider. -func (pv ProviderOrValue) Value() *Value { +func (pv ProvidedType) Value() *Value { if pv.p != nil { panic("Provider pointer converted to a Value") } diff --git a/internal/wire/testdata/StructPointer/foo/foo.go b/internal/wire/testdata/StructPointer/foo/foo.go index 87025d4..9fadf2d 100644 --- a/internal/wire/testdata/StructPointer/foo/foo.go +++ b/internal/wire/testdata/StructPointer/foo/foo.go @@ -22,7 +22,8 @@ import ( func main() { fb := injectFooBar() - fmt.Println(fb.Foo, fb.Bar) + e := injectEmptyStruct() + fmt.Printf("%d %d %v\n", fb.Foo, fb.Bar, e) } type Foo int @@ -33,6 +34,8 @@ type FooBar struct { Bar Bar } +type Empty struct{} + func provideFoo() Foo { return 41 } diff --git a/internal/wire/testdata/StructPointer/foo/wire.go b/internal/wire/testdata/StructPointer/foo/wire.go index 6e68d7b..a2d9bfe 100644 --- a/internal/wire/testdata/StructPointer/foo/wire.go +++ b/internal/wire/testdata/StructPointer/foo/wire.go @@ -24,3 +24,8 @@ func injectFooBar() *FooBar { wire.Build(Set) return nil } + +func injectEmptyStruct() *Empty { + wire.Build(Empty{}) + return nil +} diff --git a/internal/wire/testdata/StructPointer/want/program_out.txt b/internal/wire/testdata/StructPointer/want/program_out.txt index b1ae43f..a0875ea 100644 --- a/internal/wire/testdata/StructPointer/want/program_out.txt +++ b/internal/wire/testdata/StructPointer/want/program_out.txt @@ -1 +1 @@ -41 1 +41 1 &{} diff --git a/internal/wire/testdata/StructPointer/want/wire_gen.go b/internal/wire/testdata/StructPointer/want/wire_gen.go index 2ebb66d..c078331 100644 --- a/internal/wire/testdata/StructPointer/want/wire_gen.go +++ b/internal/wire/testdata/StructPointer/want/wire_gen.go @@ -16,3 +16,8 @@ func injectFooBar() *FooBar { } return fooBar } + +func injectEmptyStruct() *Empty { + empty := &Empty{} + return empty +}