diff --git a/README.md b/README.md index 429f415..118f866 100644 --- a/README.md +++ b/README.md @@ -260,9 +260,9 @@ var BarFooer = wire.NewSet( ``` The first argument to `wire.Bind` is a pointer to a value of the desired -interface type and the second argument is a zero value of the concrete type. An -interface binding does not necessarily need to have a provider in the same set -that provides the concrete type. +interface type and the second argument is a zero value of the concrete type. +Any set that includes an interface binding must also have a provider in the +same set that provides the concrete type. [type identity]: https://golang.org/ref/spec#Type_identity [return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces diff --git a/cmd/gowire/main.go b/cmd/gowire/main.go index b9c9ad3..09c6147 100644 --- a/cmd/gowire/main.go +++ b/cmd/gowire/main.go @@ -134,8 +134,6 @@ func show(pkgs ...string) error { out[types.TypeString(t, nil)] = v.Pos case *wire.Value: out[types.TypeString(t, nil)] = v.Pos - case *wire.IfaceBinding: - out[types.TypeString(t, nil)] = v.Pos default: panic("unreachable") } @@ -152,17 +150,17 @@ func show(pkgs ...string) error { type outGroup struct { name string inputs *typeutil.Map // values are not important - outputs *typeutil.Map // values are *wire.Provider, *wire.Value, or *wire.IfaceBinding + outputs *typeutil.Map // values are *wire.Provider or *wire.Value } // gather flattens a provider set into outputs grouped by the inputs // required to create them. As it flattens the provider set, it records // the visited named provider sets as imports. func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[string]struct{}) { + set := info.Sets[key] hash := typeutil.MakeHasher() - // Map types to providers and bindings. - pm := new(typeutil.Map) - pm.SetHasher(hash) + + // Find imports. next := []*wire.ProviderSet{info.Sets[key]} visited := make(map[*wire.ProviderSet]struct{}) imports = make(map[string]struct{}) @@ -176,15 +174,6 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) { imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{} } - for _, p := range curr.Providers { - pm.Set(p.Out, p) - } - for _, b := range curr.Bindings { - pm.Set(b.Iface, b) - } - for _, v := range curr.Values { - pm.Set(v.Out, v) - } for _, imp := range curr.Imports { next = append(next, imp) } @@ -194,9 +183,8 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ var groups []outGroup inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input. inputVisited.SetHasher(hash) - pmKeys := pm.Keys() var stk []types.Type - for _, k := range pmKeys { + for _, k := range set.Outputs() { // Start a DFS by picking a random unvisited node. if inputVisited.At(k) == nil { stk = append(stk, k) @@ -210,12 +198,13 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ if inputVisited.At(curr) != nil { continue } - switch p := pm.At(curr).(type) { - case nil: + switch pv := set.For(curr); { + case pv.IsNil(): // This is an input. inputVisited.Set(curr, -1) - case *wire.Provider: + case pv.IsProvider(): // Try to see if any args haven't been visited. + p := pv.Provider() allPresent := true for _, arg := range p.Args { if inputVisited.At(arg.Type) == nil { @@ -245,24 +234,25 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ } for i := range groups { if sameTypeKeys(groups[i].inputs, in) { - groups[i].outputs.Set(p.Out, p) - inputVisited.Set(p.Out, i) + groups[i].outputs.Set(curr, p) + inputVisited.Set(curr, i) continue dfs } } out := new(typeutil.Map) out.SetHasher(hash) - out.Set(p.Out, p) - inputVisited.Set(p.Out, len(groups)) + out.Set(curr, p) + inputVisited.Set(curr, len(groups)) groups = append(groups, outGroup{ inputs: in, outputs: out, }) - case *wire.Value: + case pv.IsValue(): + v := pv.Value() for i := range groups { if groups[i].inputs.Len() == 0 { - groups[i].outputs.Set(p.Out, p) - inputVisited.Set(p.Out, i) + groups[i].outputs.Set(curr, v) + inputVisited.Set(curr, i) continue dfs } } @@ -270,40 +260,8 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ in.SetHasher(hash) out := new(typeutil.Map) out.SetHasher(hash) - out.Set(p.Out, p) - inputVisited.Set(p.Out, len(groups)) - groups = append(groups, outGroup{ - inputs: in, - outputs: out, - }) - case *wire.IfaceBinding: - i, ok := inputVisited.At(p.Provided).(int) - if !ok { - stk = append(stk, curr, p.Provided) - continue dfs - } - if i != -1 { - groups[i].outputs.Set(p.Iface, p) - inputVisited.Set(p.Iface, i) - continue dfs - } - // Binding must be provided. Find or add a group. - for i := range groups { - if groups[i].inputs.Len() != 1 { - continue - } - if groups[i].inputs.At(p.Provided) != nil { - groups[i].outputs.Set(p.Iface, p) - inputVisited.Set(p.Iface, i) - continue dfs - } - } - in := new(typeutil.Map) - in.SetHasher(hash) - in.Set(p.Provided, true) - out := new(typeutil.Map) - out.SetHasher(hash) - out.Set(p.Iface, p) + out.Set(curr, v) + inputVisited.Set(curr, len(groups)) groups = append(groups, outGroup{ inputs: in, outputs: out, @@ -314,7 +272,7 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ } } - // Name and sort groups + // Name and sort groups. for i := range groups { if groups[i].inputs.Len() == 0 { groups[i].name = "no inputs" diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 7aeefe2..39c3de8 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -85,18 +85,22 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide } } } - providers, err := buildProviderMap(fset, set) - if err != nil { - return nil, err - } // Start building the mapping of type to local variable of the given type. // The first len(given) local variables are the given types. index := new(typeutil.Map) for i, g := range given { - if p := providers.At(g); p != nil { - pp := p.(*Provider) - return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, fset.Position(pp.Pos)) + if pv := set.For(g); !pv.IsNil() { + switch { + case pv.IsProvider(): + return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", + types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos)) + case pv.IsValue(): + return nil, fmt.Errorf("input of %s conflicts with value at %s", + types.TypeString(g, nil), fset.Position(pv.Value().Pos)) + default: + panic("unknown return value from ProviderSet.For") + } } index.Set(g, i) } @@ -118,14 +122,15 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide } } - switch p := providers.At(typ).(type) { - case nil: + switch pv := set.For(typ); { + case pv.IsNil(): if len(trail) == 1 { return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) } // TODO(light): Give name of provider. return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil)) - case *Provider: + case pv.IsProvider(): + p := pv.Provider() if !types.Identical(p.Out, typ) { // Interface binding. Don't create a call ourselves. if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { @@ -162,24 +167,25 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide hasCleanup: p.HasCleanup, hasErr: p.HasErr, }) - case *Value: - if !types.Identical(p.Out, typ) { + case pv.IsValue(): + v := pv.Value() + if !types.Identical(v.Out, typ) { // Interface binding. Don't create a call ourselves. - if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { + if err := visit(append(trail, ProviderInput{Type: v.Out})); err != nil { return err } - index.Set(typ, index.At(p.Out)) + index.Set(typ, index.At(v.Out)) return nil } index.Set(typ, len(given)+len(calls)) calls = append(calls, call{ kind: valueExpr, out: typ, - valueExpr: p.expr, - valueTypeInfo: p.info, + valueExpr: v.expr, + valueTypeInfo: v.info, }) default: - panic("unknown provider map value type") + panic("unknown return value from ProviderSet.For") } return nil } @@ -189,52 +195,44 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide return calls, nil } -func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, error) { - type binding struct { - *IfaceBinding - set *ProviderSet +// buildProviderMap creates the providerMap field for a given provider set. +// The given provider set's providerMap field is ignored. +func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, error) { + providerMap := new(typeutil.Map) + providerMap.SetHasher(hasher) + setMap := new(typeutil.Map) // to *ProviderSet, for error messages + setMap.SetHasher(hasher) + + // Process imports first, verifying that there are no conflicts between sets. + for _, imp := range set.Imports { + for _, k := range imp.providerMap.Keys() { + if providerMap.At(k) != nil { + return nil, bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet)) + } + providerMap.Set(k, imp.providerMap.At(k)) + setMap.Set(k, imp) + } } - providerMap := new(typeutil.Map) // to *Provider or *Value - setMap := new(typeutil.Map) // to *ProviderSet, for error messages - var bindings []binding - visited := make(map[*ProviderSet]struct{}) - next := []*ProviderSet{set} - for len(next) > 0 { - curr := next[0] - copy(next, next[1:]) - next = next[:len(next)-1] - if _, skip := visited[curr]; skip { - continue - } - visited[curr] = struct{}{} - for _, p := range curr.Providers { - if providerMap.At(p.Out) != nil { - return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)) - } - providerMap.Set(p.Out, p) - setMap.Set(p.Out, curr) - } - for _, v := range curr.Values { - if providerMap.At(v.Out) != nil { - return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)) - } - providerMap.Set(v.Out, v) - setMap.Set(v.Out, curr) - } - for _, b := range curr.Bindings { - bindings = append(bindings, binding{ - IfaceBinding: b, - set: curr, - }) - } - for _, imp := range curr.Imports { - next = append(next, imp) + // Process non-binding providers in new set. + for _, p := range set.Providers { + if providerMap.At(p.Out) != nil { + return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)) } + providerMap.Set(p.Out, p) + setMap.Set(p.Out, set) } - // Validate that bindings have their concrete type provided in the set. - // TODO(light): Move this validation up into provider set creation. - for _, b := range bindings { + for _, v := range set.Values { + if providerMap.At(v.Out) != nil { + return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)) + } + providerMap.Set(v.Out, v) + setMap.Set(v.Out, set) + } + + // Process bindings in set. Must happen after the other providers to + // ensure the concrete type is being provided. + for _, b := range set.Bindings { if providerMap.At(b.Iface) != nil { return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet)) } @@ -245,7 +243,7 @@ func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, err return nil, fmt.Errorf("%v: no binding for %s", pos, typ) } providerMap.Set(b.Iface, concrete) - setMap.Set(b.Iface, b.set) + setMap.Set(b.Iface, set) } return providerMap, nil } diff --git a/internal/wire/parse.go b/internal/wire/parse.go index d2116fc..d05fe83 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -26,6 +26,7 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/types/typeutil" ) // A ProviderSet describes a set of providers. The zero value is an empty @@ -44,6 +45,31 @@ type ProviderSet struct { Bindings []*IfaceBinding Values []*Value Imports []*ProviderSet + + // providerMap maps from provided type to a *Provider or *Value. + // It includes all of the imported types. + providerMap *typeutil.Map +} + +// Outputs returns a new slice containing the set of possible types the +// provider set can produce. The order is unspecified. +func (set *ProviderSet) Outputs() []types.Type { + return set.providerMap.Keys() +} + +// For returns the provider or value for the given type, or the zero +// ProviderOrValue. +func (set *ProviderSet) For(t types.Type) ProviderOrValue { + switch x := set.providerMap.At(t).(type) { + case nil: + return ProviderOrValue{} + case *Provider: + return ProviderOrValue{p: x} + case *Value: + return ProviderOrValue{v: x} + default: + panic("invalid value in typeMap") + } } // An IfaceBinding declares that a type should be used to satisfy inputs @@ -181,6 +207,7 @@ func (id ProviderSetID) String() string { type objectCache struct { prog *loader.Program objects map[objRef]interface{} // *Provider, *ProviderSet, *IfaceBinding, or *Value + hasher typeutil.Hasher } type objRef struct { @@ -192,6 +219,7 @@ func newObjectCache(prog *loader.Program) *objectCache { return &objectCache{ prog: prog, objects: make(map[objRef]interface{}), + hasher: typeutil.MakeHasher(), } } @@ -342,6 +370,11 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr panic("unknown item type") } } + var err error + pset.providerMap, err = buildProviderMap(oc.prog.Fset, oc.hasher, pset) + if err != nil { + return nil, err + } return pset, nil } @@ -618,3 +651,43 @@ func isWireImport(path string) bool { } return path == "github.com/google/go-cloud/wire" } + +// ProviderOrValue is a pointer to a Provider or a Value. The zero value is +// a nil pointer. +type ProviderOrValue struct { + p *Provider + v *Value +} + +// IsNil reports whether pv is the zero value. +func (pv ProviderOrValue) IsNil() bool { + return pv.p == nil && pv.v == nil +} + +// IsProvider reports whether pv points to a Provider. +func (pv ProviderOrValue) IsProvider() bool { + return pv.p != nil +} + +// IsValue reports whether pv points to a Value. +func (pv ProviderOrValue) IsValue() bool { + return pv.v != nil +} + +// Provider returns pv as a Provider pointer. It panics if pv points to a +// Value. +func (pv ProviderOrValue) Provider() *Provider { + if pv.v != nil { + panic("Value pointer converted to a Provider") + } + return pv.p +} + +// Provider returns pv as a Value pointer. It panics if pv points to a +// Provider. +func (pv ProviderOrValue) Value() *Value { + if pv.p != nil { + panic("Provider pointer converted to a Value") + } + return pv.v +} diff --git a/wire.go b/wire.go index 22a8208..9b44887 100644 --- a/wire.go +++ b/wire.go @@ -48,7 +48,17 @@ type Binding struct{} // // Example: // -// var MySet = wire.NewSet(wire.Bind(new(MyInterface), new(MyStruct))) +// type Fooer interface { +// Foo() +// } +// +// type MyFoo struct{} +// +// func (MyFoo) Foo() {} +// +// var MySet = wire.NewSet( +// MyFoo{}, +// wire.Bind(new(Fooer), new(MyFoo))) func Bind(iface, to interface{}) Binding { return Binding{} }