wire: update Provider.Out to be a slice of provided types, and keep track of the provided concrete type in ProviderSet.providerMap (google/go-cloud#332)

Update Provider.Out to be a slice of provided types, and keep track
of the provided concrete type in ProviderSet.providerMap, to more
clearly model-named struct providers (which provide both the struct
type and a pointer to the struct type).

Fixes google/go-cloud#325.
This commit is contained in:
Robert van Gent
2018-08-16 15:36:53 -07:00
committed by Ross Light
parent 86725a2b3f
commit e93f33129e
6 changed files with 88 additions and 78 deletions

View File

@@ -147,11 +147,11 @@ dfs:
p := pv.Provider() p := pv.Provider()
src := set.srcMap.At(curr.t).(*providerSetSrc) src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src) used = append(used, src)
if !types.Identical(p.Out, curr.t) { if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
i := index.At(p.Out) i := index.At(concrete)
if i == nil { if i == nil {
stk = append(stk, curr, frame{t: p.Out, from: curr.t}) stk = append(stk, curr, frame{t: concrete, from: curr.t})
continue continue
} }
index.Set(curr.t, i) index.Set(curr.t, i)
@@ -323,20 +323,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
// Process non-binding providers in new set. // Process non-binding providers in new set.
for _, p := range set.Providers { for _, p := range set.Providers {
if providerMap.At(p.Out) != nil { src := &providerSetSrc{Provider: p}
ec.add(bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))) for _, typ := range p.Out {
continue if providerMap.At(typ) != nil {
ec.add(bindingConflictError(fset, p.Pos, typ, setMap.At(typ).(*ProviderSet)))
continue
}
providerMap.Set(typ, &ProvidedType{t: typ, p: p})
srcMap.Set(typ, src)
setMap.Set(typ, set)
} }
providerMap.Set(p.Out, p)
srcMap.Set(p.Out, &providerSetSrc{Provider: p})
setMap.Set(p.Out, set)
} }
for _, v := range set.Values { for _, v := range set.Values {
if providerMap.At(v.Out) != nil { if providerMap.At(v.Out) != nil {
ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))) ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)))
continue continue
} }
providerMap.Set(v.Out, v) providerMap.Set(v.Out, &ProvidedType{t: v.Out, v: v})
srcMap.Set(v.Out, &providerSetSrc{Value: v}) srcMap.Set(v.Out, &providerSetSrc{Value: v})
setMap.Set(v.Out, set) setMap.Set(v.Out, set)
} }
@@ -388,36 +391,40 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
continue continue
} }
visited.Set(head, true) visited.Set(head, true)
switch x := providerMap.At(head).(type) { x := providerMap.At(head)
case nil: if x == nil {
// Leaf: input. // Leaf: input.
case *Value: continue
}
pt := x.(*ProvidedType)
if pt.IsValue() {
// Leaf: values do not have dependencies. // Leaf: values do not have dependencies.
case *Provider: continue
for _, arg := range x.Args { }
a := arg.Type if !pt.IsProvider() {
hasCycle := false panic("invalid provider map value")
for i, b := range curr { }
if types.Identical(a, b) { for _, arg := range pt.Provider().Args {
sb := new(strings.Builder) a := arg.Type
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil)) hasCycle := false
for j := i; j < len(curr); j++ { for i, b := range curr {
p := providerMap.At(curr[j]).(*Provider) if types.Identical(a, b) {
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name) sb := new(strings.Builder)
} fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil)) for j := i; j < len(curr); j++ {
ec.add(errors.New(sb.String())) p := providerMap.At(curr[j]).(*ProvidedType).Provider()
hasCycle = true fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
break
} }
} fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
if !hasCycle { ec.add(errors.New(sb.String()))
next := append(append([]types.Type(nil), curr...), a) hasCycle = true
stk = append(stk, next) break
} }
} }
default: if !hasCycle {
panic("invalid provider map value") next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
} }
} }
} }

View File

@@ -55,7 +55,7 @@ type ProviderSet struct {
Values []*Value Values []*Value
Imports []*ProviderSet Imports []*ProviderSet
// providerMap maps from provided type to a *Provider or *Value. // providerMap maps from provided type to a *ProvidedType.
// It includes all of the imported types. // It includes all of the imported types.
providerMap *typeutil.Map providerMap *typeutil.Map
@@ -70,19 +70,13 @@ func (set *ProviderSet) Outputs() []types.Type {
return set.providerMap.Keys() return set.providerMap.Keys()
} }
// For returns the provider or value for the given type, or the zero // For returns a ProvidedType for the given type, or the zero ProvidedType.
// ProviderOrValue. func (set *ProviderSet) For(t types.Type) ProvidedType {
func (set *ProviderSet) For(t types.Type) ProviderOrValue { pt := set.providerMap.At(t)
switch x := set.providerMap.At(t).(type) { if pt == nil {
case nil: return ProvidedType{}
return ProviderOrValue{}
case *Provider:
return ProviderOrValue{p: x}
case *Value:
return ProviderOrValue{v: x}
default:
panic("invalid value in typeMap")
} }
return *pt.(*ProvidedType)
} }
// An IfaceBinding declares that a type should be used to satisfy inputs // An IfaceBinding declares that a type should be used to satisfy inputs
@@ -122,8 +116,9 @@ type Provider struct {
// elements in Args. // elements in Args.
Fields []string Fields []string
// Out is the type this provider produces. // Out is the set of types this provider produces. It will always
Out types.Type // contain at least one type.
Out []types.Type
// HasCleanup reports whether the provider function returns a cleanup // HasCleanup reports whether the provider function returns a cleanup
// function. (Always false for structs.) // function. (Always false for structs.)
@@ -365,8 +360,7 @@ func newObjectCache(prog *loader.Program) *objectCache {
} }
// get converts a Go object into a Wire structure. It may return a // get converts a Go object into a Wire structure. It may return a
// *Provider, a structProviderPair, an *IfaceBinding, a *ProviderSet, // *Provider, an *IfaceBinding, a *ProviderSet, or a *Value.
// or a *Value.
func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) {
ref := objRef{ ref := objRef{
importPath: obj.Pkg().Path(), importPath: obj.Pkg().Path(),
@@ -422,8 +416,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
} }
// processExpr converts an expression into a Wire structure. It may // processExpr converts an expression into a Wire structure. It may
// return a *Provider, a structProviderPair, an *IfaceBinding, a // return a *Provider, an *IfaceBinding, a *ProviderSet, or a *Value.
// *ProviderSet, or a *Value.
func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varName string) (interface{}, []error) { func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varName string) (interface{}, []error) {
exprPos := oc.prog.Fset.Position(expr.Pos()) exprPos := oc.prog.Fset.Position(expr.Pos())
expr = astutil.Unparen(expr) expr = astutil.Unparen(expr)
@@ -469,19 +462,11 @@ func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varNa
if len(errs) > 0 { if len(errs) > 0 {
return nil, notePositionAll(exprPos, errs) return nil, notePositionAll(exprPos, errs)
} }
ptrp := new(Provider) return p, nil
*ptrp = *p
ptrp.Out = types.NewPointer(p.Out)
return structProviderPair{p, ptrp}, nil
} }
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
} }
type structProviderPair struct {
provider *Provider
ptrProvider *Provider
}
func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr, varName string) (*ProviderSet, []error) { func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr, varName string) (*ProviderSet, []error) {
// Assumes that call.Fun is wire.NewSet or wire.Build. // Assumes that call.Fun is wire.NewSet or wire.Build.
@@ -504,8 +489,6 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
pset.Imports = append(pset.Imports, item) pset.Imports = append(pset.Imports, item)
case *IfaceBinding: case *IfaceBinding:
pset.Bindings = append(pset.Bindings, item) pset.Bindings = append(pset.Bindings, item)
case structProviderPair:
pset.Providers = append(pset.Providers, item.provider, item.ptrProvider)
case *Value: case *Value:
pset.Values = append(pset.Values, item) pset.Values = append(pset.Values, item)
default: default:
@@ -577,7 +560,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
Name: fn.Name(), Name: fn.Name(),
Pos: fn.Pos(), Pos: fn.Pos(),
Args: make([]ProviderInput, params.Len()), Args: make([]ProviderInput, params.Len()),
Out: providerSig.out, Out: []types.Type{providerSig.out},
HasCleanup: providerSig.cleanup, HasCleanup: providerSig.cleanup,
HasErr: providerSig.err, HasErr: providerSig.err,
} }
@@ -649,7 +632,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) {
} }
// processStructProvider creates a provider for a named struct type. // processStructProvider creates a provider for a named struct type.
// It only produces the non-pointer variant. // It produces pointer and non-pointer variants via two values in Out.
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) {
out := typeName.Type() out := typeName.Type()
st, ok := out.Underlying().(*types.Struct) st, ok := out.Underlying().(*types.Struct)
@@ -665,7 +648,7 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov
Args: make([]ProviderInput, st.NumFields()), Args: make([]ProviderInput, st.NumFields()),
Fields: make([]string, st.NumFields()), Fields: make([]string, st.NumFields()),
IsStruct: true, IsStruct: true,
Out: out, Out: []types.Type{out, types.NewPointer(out)},
} }
for i := 0; i < st.NumFields(); i++ { for i := 0; i < st.NumFields(); i++ {
f := st.Field(i) f := st.Field(i)
@@ -854,31 +837,38 @@ func isProviderSetType(t types.Type) bool {
return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet" return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet"
} }
// ProviderOrValue is a pointer to a Provider or a Value. The zero value is // ProvidedType is a pointer to a Provider or a Value. The zero value is
// a nil pointer. // a nil pointer. It also holds the concrete type that the Provider or Value
type ProviderOrValue struct { // provided.
type ProvidedType struct {
t types.Type
p *Provider p *Provider
v *Value v *Value
} }
// IsNil reports whether pv is the zero value. // IsNil reports whether pv is the zero value.
func (pv ProviderOrValue) IsNil() bool { func (pv ProvidedType) IsNil() bool {
return pv.p == nil && pv.v == nil return pv.p == nil && pv.v == nil
} }
// ConcreteType returns the concrete type that was provided.
func (pv ProvidedType) ConcreteType() types.Type {
return pv.t
}
// IsProvider reports whether pv points to a Provider. // IsProvider reports whether pv points to a Provider.
func (pv ProviderOrValue) IsProvider() bool { func (pv ProvidedType) IsProvider() bool {
return pv.p != nil return pv.p != nil
} }
// IsValue reports whether pv points to a Value. // IsValue reports whether pv points to a Value.
func (pv ProviderOrValue) IsValue() bool { func (pv ProvidedType) IsValue() bool {
return pv.v != nil return pv.v != nil
} }
// Provider returns pv as a Provider pointer. It panics if pv points to a // Provider returns pv as a Provider pointer. It panics if pv points to a
// Value. // Value.
func (pv ProviderOrValue) Provider() *Provider { func (pv ProvidedType) Provider() *Provider {
if pv.v != nil { if pv.v != nil {
panic("Value pointer converted to a Provider") panic("Value pointer converted to a Provider")
} }
@@ -887,7 +877,7 @@ func (pv ProviderOrValue) Provider() *Provider {
// Value returns pv as a Value pointer. It panics if pv points to a // Value returns pv as a Value pointer. It panics if pv points to a
// Provider. // Provider.
func (pv ProviderOrValue) Value() *Value { func (pv ProvidedType) Value() *Value {
if pv.p != nil { if pv.p != nil {
panic("Provider pointer converted to a Value") panic("Provider pointer converted to a Value")
} }

View File

@@ -22,7 +22,8 @@ import (
func main() { func main() {
fb := injectFooBar() fb := injectFooBar()
fmt.Println(fb.Foo, fb.Bar) e := injectEmptyStruct()
fmt.Printf("%d %d %v\n", fb.Foo, fb.Bar, e)
} }
type Foo int type Foo int
@@ -33,6 +34,8 @@ type FooBar struct {
Bar Bar Bar Bar
} }
type Empty struct{}
func provideFoo() Foo { func provideFoo() Foo {
return 41 return 41
} }

View File

@@ -24,3 +24,8 @@ func injectFooBar() *FooBar {
wire.Build(Set) wire.Build(Set)
return nil return nil
} }
func injectEmptyStruct() *Empty {
wire.Build(Empty{})
return nil
}

View File

@@ -1 +1 @@
41 1 41 1 &{}

View File

@@ -16,3 +16,8 @@ func injectFooBar() *FooBar {
} }
return fooBar return fooBar
} }
func injectEmptyStruct() *Empty {
empty := &Empty{}
return empty
}