wire: build provider map incrementally (google/go-cloud#96)

One small breaking change: a provider set can no longer include an
interface binding to a concrete type that is not being provided
(directly or indirectly) by the provider set. I can't imagine a
reasonable use case for the previous behavior, so this likely will
catch more errors

In terms of operation, binding conflict error messages will now give
much more specific line numbers, since they will be reported closer to
where the problem occurred.

Now that provider sets gather this information, it can be exposed in
the package API. gowire now uses this information instead of
trying to build it itself.

Fixes google/go-cloud#29
This commit is contained in:
Ross Light
2018-06-15 15:58:48 -07:00
parent b2d47f8fcc
commit b12449f9e3
5 changed files with 165 additions and 126 deletions

View File

@@ -260,9 +260,9 @@ var BarFooer = wire.NewSet(
``` ```
The first argument to `wire.Bind` is a pointer to a value of the desired The first argument to `wire.Bind` is a pointer to a value of the desired
interface type and the second argument is a zero value of the concrete type. An interface type and the second argument is a zero value of the concrete type.
interface binding does not necessarily need to have a provider in the same set Any set that includes an interface binding must also have a provider in the
that provides the concrete type. same set that provides the concrete type.
[type identity]: https://golang.org/ref/spec#Type_identity [type identity]: https://golang.org/ref/spec#Type_identity
[return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces [return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces

View File

@@ -134,8 +134,6 @@ func show(pkgs ...string) error {
out[types.TypeString(t, nil)] = v.Pos out[types.TypeString(t, nil)] = v.Pos
case *wire.Value: case *wire.Value:
out[types.TypeString(t, nil)] = v.Pos out[types.TypeString(t, nil)] = v.Pos
case *wire.IfaceBinding:
out[types.TypeString(t, nil)] = v.Pos
default: default:
panic("unreachable") panic("unreachable")
} }
@@ -152,17 +150,17 @@ func show(pkgs ...string) error {
type outGroup struct { type outGroup struct {
name string name string
inputs *typeutil.Map // values are not important inputs *typeutil.Map // values are not important
outputs *typeutil.Map // values are *wire.Provider, *wire.Value, or *wire.IfaceBinding outputs *typeutil.Map // values are *wire.Provider or *wire.Value
} }
// gather flattens a provider set into outputs grouped by the inputs // gather flattens a provider set into outputs grouped by the inputs
// required to create them. As it flattens the provider set, it records // required to create them. As it flattens the provider set, it records
// the visited named provider sets as imports. // the visited named provider sets as imports.
func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[string]struct{}) { func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
set := info.Sets[key]
hash := typeutil.MakeHasher() hash := typeutil.MakeHasher()
// Map types to providers and bindings.
pm := new(typeutil.Map) // Find imports.
pm.SetHasher(hash)
next := []*wire.ProviderSet{info.Sets[key]} next := []*wire.ProviderSet{info.Sets[key]}
visited := make(map[*wire.ProviderSet]struct{}) visited := make(map[*wire.ProviderSet]struct{})
imports = make(map[string]struct{}) imports = make(map[string]struct{})
@@ -176,15 +174,6 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) { if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) {
imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{} imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{}
} }
for _, p := range curr.Providers {
pm.Set(p.Out, p)
}
for _, b := range curr.Bindings {
pm.Set(b.Iface, b)
}
for _, v := range curr.Values {
pm.Set(v.Out, v)
}
for _, imp := range curr.Imports { for _, imp := range curr.Imports {
next = append(next, imp) next = append(next, imp)
} }
@@ -194,9 +183,8 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
var groups []outGroup var groups []outGroup
inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input. inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
inputVisited.SetHasher(hash) inputVisited.SetHasher(hash)
pmKeys := pm.Keys()
var stk []types.Type var stk []types.Type
for _, k := range pmKeys { for _, k := range set.Outputs() {
// Start a DFS by picking a random unvisited node. // Start a DFS by picking a random unvisited node.
if inputVisited.At(k) == nil { if inputVisited.At(k) == nil {
stk = append(stk, k) stk = append(stk, k)
@@ -210,12 +198,13 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
if inputVisited.At(curr) != nil { if inputVisited.At(curr) != nil {
continue continue
} }
switch p := pm.At(curr).(type) { switch pv := set.For(curr); {
case nil: case pv.IsNil():
// This is an input. // This is an input.
inputVisited.Set(curr, -1) inputVisited.Set(curr, -1)
case *wire.Provider: case pv.IsProvider():
// Try to see if any args haven't been visited. // Try to see if any args haven't been visited.
p := pv.Provider()
allPresent := true allPresent := true
for _, arg := range p.Args { for _, arg := range p.Args {
if inputVisited.At(arg.Type) == nil { if inputVisited.At(arg.Type) == nil {
@@ -245,24 +234,25 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
} }
for i := range groups { for i := range groups {
if sameTypeKeys(groups[i].inputs, in) { if sameTypeKeys(groups[i].inputs, in) {
groups[i].outputs.Set(p.Out, p) groups[i].outputs.Set(curr, p)
inputVisited.Set(p.Out, i) inputVisited.Set(curr, i)
continue dfs continue dfs
} }
} }
out := new(typeutil.Map) out := new(typeutil.Map)
out.SetHasher(hash) out.SetHasher(hash)
out.Set(p.Out, p) out.Set(curr, p)
inputVisited.Set(p.Out, len(groups)) inputVisited.Set(curr, len(groups))
groups = append(groups, outGroup{ groups = append(groups, outGroup{
inputs: in, inputs: in,
outputs: out, outputs: out,
}) })
case *wire.Value: case pv.IsValue():
v := pv.Value()
for i := range groups { for i := range groups {
if groups[i].inputs.Len() == 0 { if groups[i].inputs.Len() == 0 {
groups[i].outputs.Set(p.Out, p) groups[i].outputs.Set(curr, v)
inputVisited.Set(p.Out, i) inputVisited.Set(curr, i)
continue dfs continue dfs
} }
} }
@@ -270,40 +260,8 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
in.SetHasher(hash) in.SetHasher(hash)
out := new(typeutil.Map) out := new(typeutil.Map)
out.SetHasher(hash) out.SetHasher(hash)
out.Set(p.Out, p) out.Set(curr, v)
inputVisited.Set(p.Out, len(groups)) inputVisited.Set(curr, len(groups))
groups = append(groups, outGroup{
inputs: in,
outputs: out,
})
case *wire.IfaceBinding:
i, ok := inputVisited.At(p.Provided).(int)
if !ok {
stk = append(stk, curr, p.Provided)
continue dfs
}
if i != -1 {
groups[i].outputs.Set(p.Iface, p)
inputVisited.Set(p.Iface, i)
continue dfs
}
// Binding must be provided. Find or add a group.
for i := range groups {
if groups[i].inputs.Len() != 1 {
continue
}
if groups[i].inputs.At(p.Provided) != nil {
groups[i].outputs.Set(p.Iface, p)
inputVisited.Set(p.Iface, i)
continue dfs
}
}
in := new(typeutil.Map)
in.SetHasher(hash)
in.Set(p.Provided, true)
out := new(typeutil.Map)
out.SetHasher(hash)
out.Set(p.Iface, p)
groups = append(groups, outGroup{ groups = append(groups, outGroup{
inputs: in, inputs: in,
outputs: out, outputs: out,
@@ -314,7 +272,7 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
} }
} }
// Name and sort groups // Name and sort groups.
for i := range groups { for i := range groups {
if groups[i].inputs.Len() == 0 { if groups[i].inputs.Len() == 0 {
groups[i].name = "no inputs" groups[i].name = "no inputs"

View File

@@ -85,18 +85,22 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
} }
} }
} }
providers, err := buildProviderMap(fset, set)
if err != nil {
return nil, err
}
// Start building the mapping of type to local variable of the given type. // Start building the mapping of type to local variable of the given type.
// The first len(given) local variables are the given types. // The first len(given) local variables are the given types.
index := new(typeutil.Map) index := new(typeutil.Map)
for i, g := range given { for i, g := range given {
if p := providers.At(g); p != nil { if pv := set.For(g); !pv.IsNil() {
pp := p.(*Provider) switch {
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, fset.Position(pp.Pos)) case pv.IsProvider():
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s",
types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos))
case pv.IsValue():
return nil, fmt.Errorf("input of %s conflicts with value at %s",
types.TypeString(g, nil), fset.Position(pv.Value().Pos))
default:
panic("unknown return value from ProviderSet.For")
}
} }
index.Set(g, i) index.Set(g, i)
} }
@@ -118,14 +122,15 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
} }
} }
switch p := providers.At(typ).(type) { switch pv := set.For(typ); {
case nil: case pv.IsNil():
if len(trail) == 1 { if len(trail) == 1 {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
} }
// TODO(light): Give name of provider. // TODO(light): Give name of provider.
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil)) return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
case *Provider: case pv.IsProvider():
p := pv.Provider()
if !types.Identical(p.Out, typ) { if !types.Identical(p.Out, typ) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
@@ -162,24 +167,25 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
hasCleanup: p.HasCleanup, hasCleanup: p.HasCleanup,
hasErr: p.HasErr, hasErr: p.HasErr,
}) })
case *Value: case pv.IsValue():
if !types.Identical(p.Out, typ) { v := pv.Value()
if !types.Identical(v.Out, typ) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { if err := visit(append(trail, ProviderInput{Type: v.Out})); err != nil {
return err return err
} }
index.Set(typ, index.At(p.Out)) index.Set(typ, index.At(v.Out))
return nil return nil
} }
index.Set(typ, len(given)+len(calls)) index.Set(typ, len(given)+len(calls))
calls = append(calls, call{ calls = append(calls, call{
kind: valueExpr, kind: valueExpr,
out: typ, out: typ,
valueExpr: p.expr, valueExpr: v.expr,
valueTypeInfo: p.info, valueTypeInfo: v.info,
}) })
default: default:
panic("unknown provider map value type") panic("unknown return value from ProviderSet.For")
} }
return nil return nil
} }
@@ -189,52 +195,44 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
return calls, nil return calls, nil
} }
func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, error) { // buildProviderMap creates the providerMap field for a given provider set.
type binding struct { // The given provider set's providerMap field is ignored.
*IfaceBinding func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, error) {
set *ProviderSet providerMap := new(typeutil.Map)
providerMap.SetHasher(hasher)
setMap := new(typeutil.Map) // to *ProviderSet, for error messages
setMap.SetHasher(hasher)
// Process imports first, verifying that there are no conflicts between sets.
for _, imp := range set.Imports {
for _, k := range imp.providerMap.Keys() {
if providerMap.At(k) != nil {
return nil, bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet))
}
providerMap.Set(k, imp.providerMap.At(k))
setMap.Set(k, imp)
}
} }
providerMap := new(typeutil.Map) // to *Provider or *Value // Process non-binding providers in new set.
setMap := new(typeutil.Map) // to *ProviderSet, for error messages for _, p := range set.Providers {
var bindings []binding
visited := make(map[*ProviderSet]struct{})
next := []*ProviderSet{set}
for len(next) > 0 {
curr := next[0]
copy(next, next[1:])
next = next[:len(next)-1]
if _, skip := visited[curr]; skip {
continue
}
visited[curr] = struct{}{}
for _, p := range curr.Providers {
if providerMap.At(p.Out) != nil { if providerMap.At(p.Out) != nil {
return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)) return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))
} }
providerMap.Set(p.Out, p) providerMap.Set(p.Out, p)
setMap.Set(p.Out, curr) setMap.Set(p.Out, set)
} }
for _, v := range curr.Values { for _, v := range set.Values {
if providerMap.At(v.Out) != nil { if providerMap.At(v.Out) != nil {
return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet)) return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))
} }
providerMap.Set(v.Out, v) providerMap.Set(v.Out, v)
setMap.Set(v.Out, curr) setMap.Set(v.Out, set)
} }
for _, b := range curr.Bindings {
bindings = append(bindings, binding{ // Process bindings in set. Must happen after the other providers to
IfaceBinding: b, // ensure the concrete type is being provided.
set: curr, for _, b := range set.Bindings {
})
}
for _, imp := range curr.Imports {
next = append(next, imp)
}
}
// Validate that bindings have their concrete type provided in the set.
// TODO(light): Move this validation up into provider set creation.
for _, b := range bindings {
if providerMap.At(b.Iface) != nil { if providerMap.At(b.Iface) != nil {
return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet)) return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))
} }
@@ -245,7 +243,7 @@ func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, err
return nil, fmt.Errorf("%v: no binding for %s", pos, typ) return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
} }
providerMap.Set(b.Iface, concrete) providerMap.Set(b.Iface, concrete)
setMap.Set(b.Iface, b.set) setMap.Set(b.Iface, set)
} }
return providerMap, nil return providerMap, nil
} }

View File

@@ -26,6 +26,7 @@ import (
"golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
"golang.org/x/tools/go/types/typeutil"
) )
// A ProviderSet describes a set of providers. The zero value is an empty // A ProviderSet describes a set of providers. The zero value is an empty
@@ -44,6 +45,31 @@ type ProviderSet struct {
Bindings []*IfaceBinding Bindings []*IfaceBinding
Values []*Value Values []*Value
Imports []*ProviderSet Imports []*ProviderSet
// providerMap maps from provided type to a *Provider or *Value.
// It includes all of the imported types.
providerMap *typeutil.Map
}
// Outputs returns a new slice containing the set of possible types the
// provider set can produce. The order is unspecified.
func (set *ProviderSet) Outputs() []types.Type {
return set.providerMap.Keys()
}
// For returns the provider or value for the given type, or the zero
// ProviderOrValue.
func (set *ProviderSet) For(t types.Type) ProviderOrValue {
switch x := set.providerMap.At(t).(type) {
case nil:
return ProviderOrValue{}
case *Provider:
return ProviderOrValue{p: x}
case *Value:
return ProviderOrValue{v: x}
default:
panic("invalid value in typeMap")
}
} }
// An IfaceBinding declares that a type should be used to satisfy inputs // An IfaceBinding declares that a type should be used to satisfy inputs
@@ -181,6 +207,7 @@ func (id ProviderSetID) String() string {
type objectCache struct { type objectCache struct {
prog *loader.Program prog *loader.Program
objects map[objRef]interface{} // *Provider, *ProviderSet, *IfaceBinding, or *Value objects map[objRef]interface{} // *Provider, *ProviderSet, *IfaceBinding, or *Value
hasher typeutil.Hasher
} }
type objRef struct { type objRef struct {
@@ -192,6 +219,7 @@ func newObjectCache(prog *loader.Program) *objectCache {
return &objectCache{ return &objectCache{
prog: prog, prog: prog,
objects: make(map[objRef]interface{}), objects: make(map[objRef]interface{}),
hasher: typeutil.MakeHasher(),
} }
} }
@@ -342,6 +370,11 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
panic("unknown item type") panic("unknown item type")
} }
} }
var err error
pset.providerMap, err = buildProviderMap(oc.prog.Fset, oc.hasher, pset)
if err != nil {
return nil, err
}
return pset, nil return pset, nil
} }
@@ -618,3 +651,43 @@ func isWireImport(path string) bool {
} }
return path == "github.com/google/go-cloud/wire" return path == "github.com/google/go-cloud/wire"
} }
// ProviderOrValue is a pointer to a Provider or a Value. The zero value is
// a nil pointer.
type ProviderOrValue struct {
p *Provider
v *Value
}
// IsNil reports whether pv is the zero value.
func (pv ProviderOrValue) IsNil() bool {
return pv.p == nil && pv.v == nil
}
// IsProvider reports whether pv points to a Provider.
func (pv ProviderOrValue) IsProvider() bool {
return pv.p != nil
}
// IsValue reports whether pv points to a Value.
func (pv ProviderOrValue) IsValue() bool {
return pv.v != nil
}
// Provider returns pv as a Provider pointer. It panics if pv points to a
// Value.
func (pv ProviderOrValue) Provider() *Provider {
if pv.v != nil {
panic("Value pointer converted to a Provider")
}
return pv.p
}
// Provider returns pv as a Value pointer. It panics if pv points to a
// Provider.
func (pv ProviderOrValue) Value() *Value {
if pv.p != nil {
panic("Provider pointer converted to a Value")
}
return pv.v
}

12
wire.go
View File

@@ -48,7 +48,17 @@ type Binding struct{}
// //
// Example: // Example:
// //
// var MySet = wire.NewSet(wire.Bind(new(MyInterface), new(MyStruct))) // type Fooer interface {
// Foo()
// }
//
// type MyFoo struct{}
//
// func (MyFoo) Foo() {}
//
// var MySet = wire.NewSet(
// MyFoo{},
// wire.Bind(new(Fooer), new(MyFoo)))
func Bind(iface, to interface{}) Binding { func Bind(iface, to interface{}) Binding {
return Binding{} return Binding{}
} }