wire: update Provider.Out to be a slice of provided types, and keep track of the provided concrete type in ProviderSet.providerMap (google/go-cloud#332)

Update Provider.Out to be a slice of provided types, and keep track
of the provided concrete type in ProviderSet.providerMap, to more
clearly model-named struct providers (which provide both the struct
type and a pointer to the struct type).

Fixes google/go-cloud#325.
This commit is contained in:
Robert van Gent
2018-08-16 15:36:53 -07:00
committed by Ross Light
parent 86725a2b3f
commit e93f33129e
6 changed files with 88 additions and 78 deletions

View File

@@ -147,11 +147,11 @@ dfs:
p := pv.Provider()
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
if !types.Identical(p.Out, curr.t) {
if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
// Interface binding. Don't create a call ourselves.
i := index.At(p.Out)
i := index.At(concrete)
if i == nil {
stk = append(stk, curr, frame{t: p.Out, from: curr.t})
stk = append(stk, curr, frame{t: concrete, from: curr.t})
continue
}
index.Set(curr.t, i)
@@ -323,20 +323,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
// Process non-binding providers in new set.
for _, p := range set.Providers {
if providerMap.At(p.Out) != nil {
ec.add(bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)))
continue
src := &providerSetSrc{Provider: p}
for _, typ := range p.Out {
if providerMap.At(typ) != nil {
ec.add(bindingConflictError(fset, p.Pos, typ, setMap.At(typ).(*ProviderSet)))
continue
}
providerMap.Set(typ, &ProvidedType{t: typ, p: p})
srcMap.Set(typ, src)
setMap.Set(typ, set)
}
providerMap.Set(p.Out, p)
srcMap.Set(p.Out, &providerSetSrc{Provider: p})
setMap.Set(p.Out, set)
}
for _, v := range set.Values {
if providerMap.At(v.Out) != nil {
ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)))
continue
}
providerMap.Set(v.Out, v)
providerMap.Set(v.Out, &ProvidedType{t: v.Out, v: v})
srcMap.Set(v.Out, &providerSetSrc{Value: v})
setMap.Set(v.Out, set)
}
@@ -388,36 +391,40 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
continue
}
visited.Set(head, true)
switch x := providerMap.At(head).(type) {
case nil:
x := providerMap.At(head)
if x == nil {
// Leaf: input.
case *Value:
continue
}
pt := x.(*ProvidedType)
if pt.IsValue() {
// Leaf: values do not have dependencies.
case *Provider:
for _, arg := range x.Args {
a := arg.Type
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
p := providerMap.At(curr[j]).(*Provider)
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
}
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
continue
}
if !pt.IsProvider() {
panic("invalid provider map value")
}
for _, arg := range pt.Provider().Args {
a := arg.Type
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
p := providerMap.At(curr[j]).(*ProvidedType).Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
}
}
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
}
default:
panic("invalid provider map value")
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
}
}
}

View File

@@ -55,7 +55,7 @@ type ProviderSet struct {
Values []*Value
Imports []*ProviderSet
// providerMap maps from provided type to a *Provider or *Value.
// providerMap maps from provided type to a *ProvidedType.
// It includes all of the imported types.
providerMap *typeutil.Map
@@ -70,19 +70,13 @@ func (set *ProviderSet) Outputs() []types.Type {
return set.providerMap.Keys()
}
// For returns the provider or value for the given type, or the zero
// ProviderOrValue.
func (set *ProviderSet) For(t types.Type) ProviderOrValue {
switch x := set.providerMap.At(t).(type) {
case nil:
return ProviderOrValue{}
case *Provider:
return ProviderOrValue{p: x}
case *Value:
return ProviderOrValue{v: x}
default:
panic("invalid value in typeMap")
// For returns a ProvidedType for the given type, or the zero ProvidedType.
func (set *ProviderSet) For(t types.Type) ProvidedType {
pt := set.providerMap.At(t)
if pt == nil {
return ProvidedType{}
}
return *pt.(*ProvidedType)
}
// An IfaceBinding declares that a type should be used to satisfy inputs
@@ -122,8 +116,9 @@ type Provider struct {
// elements in Args.
Fields []string
// Out is the type this provider produces.
Out types.Type
// Out is the set of types this provider produces. It will always
// contain at least one type.
Out []types.Type
// HasCleanup reports whether the provider function returns a cleanup
// function. (Always false for structs.)
@@ -365,8 +360,7 @@ func newObjectCache(prog *loader.Program) *objectCache {
}
// get converts a Go object into a Wire structure. It may return a
// *Provider, a structProviderPair, an *IfaceBinding, a *ProviderSet,
// or a *Value.
// *Provider, an *IfaceBinding, a *ProviderSet, or a *Value.
func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) {
ref := objRef{
importPath: obj.Pkg().Path(),
@@ -422,8 +416,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
}
// processExpr converts an expression into a Wire structure. It may
// return a *Provider, a structProviderPair, an *IfaceBinding, a
// *ProviderSet, or a *Value.
// return a *Provider, an *IfaceBinding, a *ProviderSet, or a *Value.
func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varName string) (interface{}, []error) {
exprPos := oc.prog.Fset.Position(expr.Pos())
expr = astutil.Unparen(expr)
@@ -469,19 +462,11 @@ func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varNa
if len(errs) > 0 {
return nil, notePositionAll(exprPos, errs)
}
ptrp := new(Provider)
*ptrp = *p
ptrp.Out = types.NewPointer(p.Out)
return structProviderPair{p, ptrp}, nil
return p, nil
}
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
}
type structProviderPair struct {
provider *Provider
ptrProvider *Provider
}
func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr, varName string) (*ProviderSet, []error) {
// Assumes that call.Fun is wire.NewSet or wire.Build.
@@ -504,8 +489,6 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
pset.Imports = append(pset.Imports, item)
case *IfaceBinding:
pset.Bindings = append(pset.Bindings, item)
case structProviderPair:
pset.Providers = append(pset.Providers, item.provider, item.ptrProvider)
case *Value:
pset.Values = append(pset.Values, item)
default:
@@ -577,7 +560,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
Name: fn.Name(),
Pos: fn.Pos(),
Args: make([]ProviderInput, params.Len()),
Out: providerSig.out,
Out: []types.Type{providerSig.out},
HasCleanup: providerSig.cleanup,
HasErr: providerSig.err,
}
@@ -649,7 +632,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) {
}
// processStructProvider creates a provider for a named struct type.
// It only produces the non-pointer variant.
// It produces pointer and non-pointer variants via two values in Out.
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) {
out := typeName.Type()
st, ok := out.Underlying().(*types.Struct)
@@ -665,7 +648,7 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov
Args: make([]ProviderInput, st.NumFields()),
Fields: make([]string, st.NumFields()),
IsStruct: true,
Out: out,
Out: []types.Type{out, types.NewPointer(out)},
}
for i := 0; i < st.NumFields(); i++ {
f := st.Field(i)
@@ -854,31 +837,38 @@ func isProviderSetType(t types.Type) bool {
return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet"
}
// ProviderOrValue is a pointer to a Provider or a Value. The zero value is
// a nil pointer.
type ProviderOrValue struct {
// ProvidedType is a pointer to a Provider or a Value. The zero value is
// a nil pointer. It also holds the concrete type that the Provider or Value
// provided.
type ProvidedType struct {
t types.Type
p *Provider
v *Value
}
// IsNil reports whether pv is the zero value.
func (pv ProviderOrValue) IsNil() bool {
func (pv ProvidedType) IsNil() bool {
return pv.p == nil && pv.v == nil
}
// ConcreteType returns the concrete type that was provided.
func (pv ProvidedType) ConcreteType() types.Type {
return pv.t
}
// IsProvider reports whether pv points to a Provider.
func (pv ProviderOrValue) IsProvider() bool {
func (pv ProvidedType) IsProvider() bool {
return pv.p != nil
}
// IsValue reports whether pv points to a Value.
func (pv ProviderOrValue) IsValue() bool {
func (pv ProvidedType) IsValue() bool {
return pv.v != nil
}
// Provider returns pv as a Provider pointer. It panics if pv points to a
// Value.
func (pv ProviderOrValue) Provider() *Provider {
func (pv ProvidedType) Provider() *Provider {
if pv.v != nil {
panic("Value pointer converted to a Provider")
}
@@ -887,7 +877,7 @@ func (pv ProviderOrValue) Provider() *Provider {
// Value returns pv as a Value pointer. It panics if pv points to a
// Provider.
func (pv ProviderOrValue) Value() *Value {
func (pv ProvidedType) Value() *Value {
if pv.p != nil {
panic("Provider pointer converted to a Value")
}

View File

@@ -22,7 +22,8 @@ import (
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.Bar)
e := injectEmptyStruct()
fmt.Printf("%d %d %v\n", fb.Foo, fb.Bar, e)
}
type Foo int
@@ -33,6 +34,8 @@ type FooBar struct {
Bar Bar
}
type Empty struct{}
func provideFoo() Foo {
return 41
}

View File

@@ -24,3 +24,8 @@ func injectFooBar() *FooBar {
wire.Build(Set)
return nil
}
func injectEmptyStruct() *Empty {
wire.Build(Empty{})
return nil
}

View File

@@ -1 +1 @@
41 1
41 1 &{}

View File

@@ -16,3 +16,8 @@ func injectFooBar() *FooBar {
}
return fooBar
}
func injectEmptyStruct() *Empty {
empty := &Empty{}
return empty
}