wire: update Provider.Out to be a slice of provided types, and keep track of the provided concrete type in ProviderSet.providerMap (google/go-cloud#332)
Update Provider.Out to be a slice of provided types, and keep track of the provided concrete type in ProviderSet.providerMap, to more clearly model-named struct providers (which provide both the struct type and a pointer to the struct type). Fixes google/go-cloud#325.
This commit is contained in:
committed by
Ross Light
parent
86725a2b3f
commit
e93f33129e
@@ -147,11 +147,11 @@ dfs:
|
|||||||
p := pv.Provider()
|
p := pv.Provider()
|
||||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||||
used = append(used, src)
|
used = append(used, src)
|
||||||
if !types.Identical(p.Out, curr.t) {
|
if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
|
||||||
// Interface binding. Don't create a call ourselves.
|
// Interface binding. Don't create a call ourselves.
|
||||||
i := index.At(p.Out)
|
i := index.At(concrete)
|
||||||
if i == nil {
|
if i == nil {
|
||||||
stk = append(stk, curr, frame{t: p.Out, from: curr.t})
|
stk = append(stk, curr, frame{t: concrete, from: curr.t})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
index.Set(curr.t, i)
|
index.Set(curr.t, i)
|
||||||
@@ -323,20 +323,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
|
|||||||
|
|
||||||
// Process non-binding providers in new set.
|
// Process non-binding providers in new set.
|
||||||
for _, p := range set.Providers {
|
for _, p := range set.Providers {
|
||||||
if providerMap.At(p.Out) != nil {
|
src := &providerSetSrc{Provider: p}
|
||||||
ec.add(bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)))
|
for _, typ := range p.Out {
|
||||||
continue
|
if providerMap.At(typ) != nil {
|
||||||
|
ec.add(bindingConflictError(fset, p.Pos, typ, setMap.At(typ).(*ProviderSet)))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerMap.Set(typ, &ProvidedType{t: typ, p: p})
|
||||||
|
srcMap.Set(typ, src)
|
||||||
|
setMap.Set(typ, set)
|
||||||
}
|
}
|
||||||
providerMap.Set(p.Out, p)
|
|
||||||
srcMap.Set(p.Out, &providerSetSrc{Provider: p})
|
|
||||||
setMap.Set(p.Out, set)
|
|
||||||
}
|
}
|
||||||
for _, v := range set.Values {
|
for _, v := range set.Values {
|
||||||
if providerMap.At(v.Out) != nil {
|
if providerMap.At(v.Out) != nil {
|
||||||
ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)))
|
ec.add(bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
providerMap.Set(v.Out, v)
|
providerMap.Set(v.Out, &ProvidedType{t: v.Out, v: v})
|
||||||
srcMap.Set(v.Out, &providerSetSrc{Value: v})
|
srcMap.Set(v.Out, &providerSetSrc{Value: v})
|
||||||
setMap.Set(v.Out, set)
|
setMap.Set(v.Out, set)
|
||||||
}
|
}
|
||||||
@@ -388,36 +391,40 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
visited.Set(head, true)
|
visited.Set(head, true)
|
||||||
switch x := providerMap.At(head).(type) {
|
x := providerMap.At(head)
|
||||||
case nil:
|
if x == nil {
|
||||||
// Leaf: input.
|
// Leaf: input.
|
||||||
case *Value:
|
continue
|
||||||
|
}
|
||||||
|
pt := x.(*ProvidedType)
|
||||||
|
if pt.IsValue() {
|
||||||
// Leaf: values do not have dependencies.
|
// Leaf: values do not have dependencies.
|
||||||
case *Provider:
|
continue
|
||||||
for _, arg := range x.Args {
|
}
|
||||||
a := arg.Type
|
if !pt.IsProvider() {
|
||||||
hasCycle := false
|
panic("invalid provider map value")
|
||||||
for i, b := range curr {
|
}
|
||||||
if types.Identical(a, b) {
|
for _, arg := range pt.Provider().Args {
|
||||||
sb := new(strings.Builder)
|
a := arg.Type
|
||||||
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
|
hasCycle := false
|
||||||
for j := i; j < len(curr); j++ {
|
for i, b := range curr {
|
||||||
p := providerMap.At(curr[j]).(*Provider)
|
if types.Identical(a, b) {
|
||||||
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
|
sb := new(strings.Builder)
|
||||||
}
|
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
|
||||||
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
|
for j := i; j < len(curr); j++ {
|
||||||
ec.add(errors.New(sb.String()))
|
p := providerMap.At(curr[j]).(*ProvidedType).Provider()
|
||||||
hasCycle = true
|
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.ImportPath, p.Name)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
fmt.Fprintf(sb, "%s\n", types.TypeString(a, nil))
|
||||||
if !hasCycle {
|
ec.add(errors.New(sb.String()))
|
||||||
next := append(append([]types.Type(nil), curr...), a)
|
hasCycle = true
|
||||||
stk = append(stk, next)
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
if !hasCycle {
|
||||||
panic("invalid provider map value")
|
next := append(append([]types.Type(nil), curr...), a)
|
||||||
|
stk = append(stk, next)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ type ProviderSet struct {
|
|||||||
Values []*Value
|
Values []*Value
|
||||||
Imports []*ProviderSet
|
Imports []*ProviderSet
|
||||||
|
|
||||||
// providerMap maps from provided type to a *Provider or *Value.
|
// providerMap maps from provided type to a *ProvidedType.
|
||||||
// It includes all of the imported types.
|
// It includes all of the imported types.
|
||||||
providerMap *typeutil.Map
|
providerMap *typeutil.Map
|
||||||
|
|
||||||
@@ -70,19 +70,13 @@ func (set *ProviderSet) Outputs() []types.Type {
|
|||||||
return set.providerMap.Keys()
|
return set.providerMap.Keys()
|
||||||
}
|
}
|
||||||
|
|
||||||
// For returns the provider or value for the given type, or the zero
|
// For returns a ProvidedType for the given type, or the zero ProvidedType.
|
||||||
// ProviderOrValue.
|
func (set *ProviderSet) For(t types.Type) ProvidedType {
|
||||||
func (set *ProviderSet) For(t types.Type) ProviderOrValue {
|
pt := set.providerMap.At(t)
|
||||||
switch x := set.providerMap.At(t).(type) {
|
if pt == nil {
|
||||||
case nil:
|
return ProvidedType{}
|
||||||
return ProviderOrValue{}
|
|
||||||
case *Provider:
|
|
||||||
return ProviderOrValue{p: x}
|
|
||||||
case *Value:
|
|
||||||
return ProviderOrValue{v: x}
|
|
||||||
default:
|
|
||||||
panic("invalid value in typeMap")
|
|
||||||
}
|
}
|
||||||
|
return *pt.(*ProvidedType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// An IfaceBinding declares that a type should be used to satisfy inputs
|
// An IfaceBinding declares that a type should be used to satisfy inputs
|
||||||
@@ -122,8 +116,9 @@ type Provider struct {
|
|||||||
// elements in Args.
|
// elements in Args.
|
||||||
Fields []string
|
Fields []string
|
||||||
|
|
||||||
// Out is the type this provider produces.
|
// Out is the set of types this provider produces. It will always
|
||||||
Out types.Type
|
// contain at least one type.
|
||||||
|
Out []types.Type
|
||||||
|
|
||||||
// HasCleanup reports whether the provider function returns a cleanup
|
// HasCleanup reports whether the provider function returns a cleanup
|
||||||
// function. (Always false for structs.)
|
// function. (Always false for structs.)
|
||||||
@@ -365,8 +360,7 @@ func newObjectCache(prog *loader.Program) *objectCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get converts a Go object into a Wire structure. It may return a
|
// get converts a Go object into a Wire structure. It may return a
|
||||||
// *Provider, a structProviderPair, an *IfaceBinding, a *ProviderSet,
|
// *Provider, an *IfaceBinding, a *ProviderSet, or a *Value.
|
||||||
// or a *Value.
|
|
||||||
func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) {
|
func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) {
|
||||||
ref := objRef{
|
ref := objRef{
|
||||||
importPath: obj.Pkg().Path(),
|
importPath: obj.Pkg().Path(),
|
||||||
@@ -422,8 +416,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processExpr converts an expression into a Wire structure. It may
|
// processExpr converts an expression into a Wire structure. It may
|
||||||
// return a *Provider, a structProviderPair, an *IfaceBinding, a
|
// return a *Provider, an *IfaceBinding, a *ProviderSet, or a *Value.
|
||||||
// *ProviderSet, or a *Value.
|
|
||||||
func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varName string) (interface{}, []error) {
|
func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varName string) (interface{}, []error) {
|
||||||
exprPos := oc.prog.Fset.Position(expr.Pos())
|
exprPos := oc.prog.Fset.Position(expr.Pos())
|
||||||
expr = astutil.Unparen(expr)
|
expr = astutil.Unparen(expr)
|
||||||
@@ -469,19 +462,11 @@ func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr, varNa
|
|||||||
if len(errs) > 0 {
|
if len(errs) > 0 {
|
||||||
return nil, notePositionAll(exprPos, errs)
|
return nil, notePositionAll(exprPos, errs)
|
||||||
}
|
}
|
||||||
ptrp := new(Provider)
|
return p, nil
|
||||||
*ptrp = *p
|
|
||||||
ptrp.Out = types.NewPointer(p.Out)
|
|
||||||
return structProviderPair{p, ptrp}, nil
|
|
||||||
}
|
}
|
||||||
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
|
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
|
||||||
}
|
}
|
||||||
|
|
||||||
type structProviderPair struct {
|
|
||||||
provider *Provider
|
|
||||||
ptrProvider *Provider
|
|
||||||
}
|
|
||||||
|
|
||||||
func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr, varName string) (*ProviderSet, []error) {
|
func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr, varName string) (*ProviderSet, []error) {
|
||||||
// Assumes that call.Fun is wire.NewSet or wire.Build.
|
// Assumes that call.Fun is wire.NewSet or wire.Build.
|
||||||
|
|
||||||
@@ -504,8 +489,6 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
|
|||||||
pset.Imports = append(pset.Imports, item)
|
pset.Imports = append(pset.Imports, item)
|
||||||
case *IfaceBinding:
|
case *IfaceBinding:
|
||||||
pset.Bindings = append(pset.Bindings, item)
|
pset.Bindings = append(pset.Bindings, item)
|
||||||
case structProviderPair:
|
|
||||||
pset.Providers = append(pset.Providers, item.provider, item.ptrProvider)
|
|
||||||
case *Value:
|
case *Value:
|
||||||
pset.Values = append(pset.Values, item)
|
pset.Values = append(pset.Values, item)
|
||||||
default:
|
default:
|
||||||
@@ -577,7 +560,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
|
|||||||
Name: fn.Name(),
|
Name: fn.Name(),
|
||||||
Pos: fn.Pos(),
|
Pos: fn.Pos(),
|
||||||
Args: make([]ProviderInput, params.Len()),
|
Args: make([]ProviderInput, params.Len()),
|
||||||
Out: providerSig.out,
|
Out: []types.Type{providerSig.out},
|
||||||
HasCleanup: providerSig.cleanup,
|
HasCleanup: providerSig.cleanup,
|
||||||
HasErr: providerSig.err,
|
HasErr: providerSig.err,
|
||||||
}
|
}
|
||||||
@@ -649,7 +632,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processStructProvider creates a provider for a named struct type.
|
// processStructProvider creates a provider for a named struct type.
|
||||||
// It only produces the non-pointer variant.
|
// It produces pointer and non-pointer variants via two values in Out.
|
||||||
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) {
|
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) {
|
||||||
out := typeName.Type()
|
out := typeName.Type()
|
||||||
st, ok := out.Underlying().(*types.Struct)
|
st, ok := out.Underlying().(*types.Struct)
|
||||||
@@ -665,7 +648,7 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov
|
|||||||
Args: make([]ProviderInput, st.NumFields()),
|
Args: make([]ProviderInput, st.NumFields()),
|
||||||
Fields: make([]string, st.NumFields()),
|
Fields: make([]string, st.NumFields()),
|
||||||
IsStruct: true,
|
IsStruct: true,
|
||||||
Out: out,
|
Out: []types.Type{out, types.NewPointer(out)},
|
||||||
}
|
}
|
||||||
for i := 0; i < st.NumFields(); i++ {
|
for i := 0; i < st.NumFields(); i++ {
|
||||||
f := st.Field(i)
|
f := st.Field(i)
|
||||||
@@ -854,31 +837,38 @@ func isProviderSetType(t types.Type) bool {
|
|||||||
return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet"
|
return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderOrValue is a pointer to a Provider or a Value. The zero value is
|
// ProvidedType is a pointer to a Provider or a Value. The zero value is
|
||||||
// a nil pointer.
|
// a nil pointer. It also holds the concrete type that the Provider or Value
|
||||||
type ProviderOrValue struct {
|
// provided.
|
||||||
|
type ProvidedType struct {
|
||||||
|
t types.Type
|
||||||
p *Provider
|
p *Provider
|
||||||
v *Value
|
v *Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsNil reports whether pv is the zero value.
|
// IsNil reports whether pv is the zero value.
|
||||||
func (pv ProviderOrValue) IsNil() bool {
|
func (pv ProvidedType) IsNil() bool {
|
||||||
return pv.p == nil && pv.v == nil
|
return pv.p == nil && pv.v == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConcreteType returns the concrete type that was provided.
|
||||||
|
func (pv ProvidedType) ConcreteType() types.Type {
|
||||||
|
return pv.t
|
||||||
|
}
|
||||||
|
|
||||||
// IsProvider reports whether pv points to a Provider.
|
// IsProvider reports whether pv points to a Provider.
|
||||||
func (pv ProviderOrValue) IsProvider() bool {
|
func (pv ProvidedType) IsProvider() bool {
|
||||||
return pv.p != nil
|
return pv.p != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValue reports whether pv points to a Value.
|
// IsValue reports whether pv points to a Value.
|
||||||
func (pv ProviderOrValue) IsValue() bool {
|
func (pv ProvidedType) IsValue() bool {
|
||||||
return pv.v != nil
|
return pv.v != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provider returns pv as a Provider pointer. It panics if pv points to a
|
// Provider returns pv as a Provider pointer. It panics if pv points to a
|
||||||
// Value.
|
// Value.
|
||||||
func (pv ProviderOrValue) Provider() *Provider {
|
func (pv ProvidedType) Provider() *Provider {
|
||||||
if pv.v != nil {
|
if pv.v != nil {
|
||||||
panic("Value pointer converted to a Provider")
|
panic("Value pointer converted to a Provider")
|
||||||
}
|
}
|
||||||
@@ -887,7 +877,7 @@ func (pv ProviderOrValue) Provider() *Provider {
|
|||||||
|
|
||||||
// Value returns pv as a Value pointer. It panics if pv points to a
|
// Value returns pv as a Value pointer. It panics if pv points to a
|
||||||
// Provider.
|
// Provider.
|
||||||
func (pv ProviderOrValue) Value() *Value {
|
func (pv ProvidedType) Value() *Value {
|
||||||
if pv.p != nil {
|
if pv.p != nil {
|
||||||
panic("Provider pointer converted to a Value")
|
panic("Provider pointer converted to a Value")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fb := injectFooBar()
|
fb := injectFooBar()
|
||||||
fmt.Println(fb.Foo, fb.Bar)
|
e := injectEmptyStruct()
|
||||||
|
fmt.Printf("%d %d %v\n", fb.Foo, fb.Bar, e)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Foo int
|
type Foo int
|
||||||
@@ -33,6 +34,8 @@ type FooBar struct {
|
|||||||
Bar Bar
|
Bar Bar
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Empty struct{}
|
||||||
|
|
||||||
func provideFoo() Foo {
|
func provideFoo() Foo {
|
||||||
return 41
|
return 41
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,3 +24,8 @@ func injectFooBar() *FooBar {
|
|||||||
wire.Build(Set)
|
wire.Build(Set)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func injectEmptyStruct() *Empty {
|
||||||
|
wire.Build(Empty{})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
41 1
|
41 1 &{}
|
||||||
|
|||||||
@@ -16,3 +16,8 @@ func injectFooBar() *FooBar {
|
|||||||
}
|
}
|
||||||
return fooBar
|
return fooBar
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func injectEmptyStruct() *Empty {
|
||||||
|
empty := &Empty{}
|
||||||
|
return empty
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user