wire: build provider map incrementally (google/go-cloud#96)
One small breaking change: a provider set can no longer include an interface binding to a concrete type that is not being provided (directly or indirectly) by the provider set. I can't imagine a reasonable use case for the previous behavior, so this likely will catch more errors In terms of operation, binding conflict error messages will now give much more specific line numbers, since they will be reported closer to where the problem occurred. Now that provider sets gather this information, it can be exposed in the package API. gowire now uses this information instead of trying to build it itself. Fixes google/go-cloud#29
This commit is contained in:
@@ -134,8 +134,6 @@ func show(pkgs ...string) error {
|
||||
out[types.TypeString(t, nil)] = v.Pos
|
||||
case *wire.Value:
|
||||
out[types.TypeString(t, nil)] = v.Pos
|
||||
case *wire.IfaceBinding:
|
||||
out[types.TypeString(t, nil)] = v.Pos
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
@@ -152,17 +150,17 @@ func show(pkgs ...string) error {
|
||||
type outGroup struct {
|
||||
name string
|
||||
inputs *typeutil.Map // values are not important
|
||||
outputs *typeutil.Map // values are *wire.Provider, *wire.Value, or *wire.IfaceBinding
|
||||
outputs *typeutil.Map // values are *wire.Provider or *wire.Value
|
||||
}
|
||||
|
||||
// gather flattens a provider set into outputs grouped by the inputs
|
||||
// required to create them. As it flattens the provider set, it records
|
||||
// the visited named provider sets as imports.
|
||||
func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
|
||||
set := info.Sets[key]
|
||||
hash := typeutil.MakeHasher()
|
||||
// Map types to providers and bindings.
|
||||
pm := new(typeutil.Map)
|
||||
pm.SetHasher(hash)
|
||||
|
||||
// Find imports.
|
||||
next := []*wire.ProviderSet{info.Sets[key]}
|
||||
visited := make(map[*wire.ProviderSet]struct{})
|
||||
imports = make(map[string]struct{})
|
||||
@@ -176,15 +174,6 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) {
|
||||
imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{}
|
||||
}
|
||||
for _, p := range curr.Providers {
|
||||
pm.Set(p.Out, p)
|
||||
}
|
||||
for _, b := range curr.Bindings {
|
||||
pm.Set(b.Iface, b)
|
||||
}
|
||||
for _, v := range curr.Values {
|
||||
pm.Set(v.Out, v)
|
||||
}
|
||||
for _, imp := range curr.Imports {
|
||||
next = append(next, imp)
|
||||
}
|
||||
@@ -194,9 +183,8 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
var groups []outGroup
|
||||
inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
|
||||
inputVisited.SetHasher(hash)
|
||||
pmKeys := pm.Keys()
|
||||
var stk []types.Type
|
||||
for _, k := range pmKeys {
|
||||
for _, k := range set.Outputs() {
|
||||
// Start a DFS by picking a random unvisited node.
|
||||
if inputVisited.At(k) == nil {
|
||||
stk = append(stk, k)
|
||||
@@ -210,12 +198,13 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
if inputVisited.At(curr) != nil {
|
||||
continue
|
||||
}
|
||||
switch p := pm.At(curr).(type) {
|
||||
case nil:
|
||||
switch pv := set.For(curr); {
|
||||
case pv.IsNil():
|
||||
// This is an input.
|
||||
inputVisited.Set(curr, -1)
|
||||
case *wire.Provider:
|
||||
case pv.IsProvider():
|
||||
// Try to see if any args haven't been visited.
|
||||
p := pv.Provider()
|
||||
allPresent := true
|
||||
for _, arg := range p.Args {
|
||||
if inputVisited.At(arg.Type) == nil {
|
||||
@@ -245,24 +234,25 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
}
|
||||
for i := range groups {
|
||||
if sameTypeKeys(groups[i].inputs, in) {
|
||||
groups[i].outputs.Set(p.Out, p)
|
||||
inputVisited.Set(p.Out, i)
|
||||
groups[i].outputs.Set(curr, p)
|
||||
inputVisited.Set(curr, i)
|
||||
continue dfs
|
||||
}
|
||||
}
|
||||
out := new(typeutil.Map)
|
||||
out.SetHasher(hash)
|
||||
out.Set(p.Out, p)
|
||||
inputVisited.Set(p.Out, len(groups))
|
||||
out.Set(curr, p)
|
||||
inputVisited.Set(curr, len(groups))
|
||||
groups = append(groups, outGroup{
|
||||
inputs: in,
|
||||
outputs: out,
|
||||
})
|
||||
case *wire.Value:
|
||||
case pv.IsValue():
|
||||
v := pv.Value()
|
||||
for i := range groups {
|
||||
if groups[i].inputs.Len() == 0 {
|
||||
groups[i].outputs.Set(p.Out, p)
|
||||
inputVisited.Set(p.Out, i)
|
||||
groups[i].outputs.Set(curr, v)
|
||||
inputVisited.Set(curr, i)
|
||||
continue dfs
|
||||
}
|
||||
}
|
||||
@@ -270,40 +260,8 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
in.SetHasher(hash)
|
||||
out := new(typeutil.Map)
|
||||
out.SetHasher(hash)
|
||||
out.Set(p.Out, p)
|
||||
inputVisited.Set(p.Out, len(groups))
|
||||
groups = append(groups, outGroup{
|
||||
inputs: in,
|
||||
outputs: out,
|
||||
})
|
||||
case *wire.IfaceBinding:
|
||||
i, ok := inputVisited.At(p.Provided).(int)
|
||||
if !ok {
|
||||
stk = append(stk, curr, p.Provided)
|
||||
continue dfs
|
||||
}
|
||||
if i != -1 {
|
||||
groups[i].outputs.Set(p.Iface, p)
|
||||
inputVisited.Set(p.Iface, i)
|
||||
continue dfs
|
||||
}
|
||||
// Binding must be provided. Find or add a group.
|
||||
for i := range groups {
|
||||
if groups[i].inputs.Len() != 1 {
|
||||
continue
|
||||
}
|
||||
if groups[i].inputs.At(p.Provided) != nil {
|
||||
groups[i].outputs.Set(p.Iface, p)
|
||||
inputVisited.Set(p.Iface, i)
|
||||
continue dfs
|
||||
}
|
||||
}
|
||||
in := new(typeutil.Map)
|
||||
in.SetHasher(hash)
|
||||
in.Set(p.Provided, true)
|
||||
out := new(typeutil.Map)
|
||||
out.SetHasher(hash)
|
||||
out.Set(p.Iface, p)
|
||||
out.Set(curr, v)
|
||||
inputVisited.Set(curr, len(groups))
|
||||
groups = append(groups, outGroup{
|
||||
inputs: in,
|
||||
outputs: out,
|
||||
@@ -314,7 +272,7 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
}
|
||||
}
|
||||
|
||||
// Name and sort groups
|
||||
// Name and sort groups.
|
||||
for i := range groups {
|
||||
if groups[i].inputs.Len() == 0 {
|
||||
groups[i].name = "no inputs"
|
||||
|
||||
Reference in New Issue
Block a user