wire: build provider map incrementally (google/go-cloud#96)
One small breaking change: a provider set can no longer include an interface binding to a concrete type that is not being provided (directly or indirectly) by the provider set. I can't imagine a reasonable use case for the previous behavior, so this likely will catch more errors In terms of operation, binding conflict error messages will now give much more specific line numbers, since they will be reported closer to where the problem occurred. Now that provider sets gather this information, it can be exposed in the package API. gowire now uses this information instead of trying to build it itself. Fixes google/go-cloud#29
This commit is contained in:
@@ -85,18 +85,22 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
}
|
||||
}
|
||||
}
|
||||
providers, err := buildProviderMap(fset, set)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start building the mapping of type to local variable of the given type.
|
||||
// The first len(given) local variables are the given types.
|
||||
index := new(typeutil.Map)
|
||||
for i, g := range given {
|
||||
if p := providers.At(g); p != nil {
|
||||
pp := p.(*Provider)
|
||||
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, fset.Position(pp.Pos))
|
||||
if pv := set.For(g); !pv.IsNil() {
|
||||
switch {
|
||||
case pv.IsProvider():
|
||||
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s",
|
||||
types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos))
|
||||
case pv.IsValue():
|
||||
return nil, fmt.Errorf("input of %s conflicts with value at %s",
|
||||
types.TypeString(g, nil), fset.Position(pv.Value().Pos))
|
||||
default:
|
||||
panic("unknown return value from ProviderSet.For")
|
||||
}
|
||||
}
|
||||
index.Set(g, i)
|
||||
}
|
||||
@@ -118,14 +122,15 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
}
|
||||
}
|
||||
|
||||
switch p := providers.At(typ).(type) {
|
||||
case nil:
|
||||
switch pv := set.For(typ); {
|
||||
case pv.IsNil():
|
||||
if len(trail) == 1 {
|
||||
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
||||
}
|
||||
// TODO(light): Give name of provider.
|
||||
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
|
||||
case *Provider:
|
||||
case pv.IsProvider():
|
||||
p := pv.Provider()
|
||||
if !types.Identical(p.Out, typ) {
|
||||
// Interface binding. Don't create a call ourselves.
|
||||
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
|
||||
@@ -162,24 +167,25 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
hasCleanup: p.HasCleanup,
|
||||
hasErr: p.HasErr,
|
||||
})
|
||||
case *Value:
|
||||
if !types.Identical(p.Out, typ) {
|
||||
case pv.IsValue():
|
||||
v := pv.Value()
|
||||
if !types.Identical(v.Out, typ) {
|
||||
// Interface binding. Don't create a call ourselves.
|
||||
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
|
||||
if err := visit(append(trail, ProviderInput{Type: v.Out})); err != nil {
|
||||
return err
|
||||
}
|
||||
index.Set(typ, index.At(p.Out))
|
||||
index.Set(typ, index.At(v.Out))
|
||||
return nil
|
||||
}
|
||||
index.Set(typ, len(given)+len(calls))
|
||||
calls = append(calls, call{
|
||||
kind: valueExpr,
|
||||
out: typ,
|
||||
valueExpr: p.expr,
|
||||
valueTypeInfo: p.info,
|
||||
valueExpr: v.expr,
|
||||
valueTypeInfo: v.info,
|
||||
})
|
||||
default:
|
||||
panic("unknown provider map value type")
|
||||
panic("unknown return value from ProviderSet.For")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -189,52 +195,44 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, error) {
|
||||
type binding struct {
|
||||
*IfaceBinding
|
||||
set *ProviderSet
|
||||
// buildProviderMap creates the providerMap field for a given provider set.
|
||||
// The given provider set's providerMap field is ignored.
|
||||
func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, error) {
|
||||
providerMap := new(typeutil.Map)
|
||||
providerMap.SetHasher(hasher)
|
||||
setMap := new(typeutil.Map) // to *ProviderSet, for error messages
|
||||
setMap.SetHasher(hasher)
|
||||
|
||||
// Process imports first, verifying that there are no conflicts between sets.
|
||||
for _, imp := range set.Imports {
|
||||
for _, k := range imp.providerMap.Keys() {
|
||||
if providerMap.At(k) != nil {
|
||||
return nil, bindingConflictError(fset, imp.Pos, k, setMap.At(k).(*ProviderSet))
|
||||
}
|
||||
providerMap.Set(k, imp.providerMap.At(k))
|
||||
setMap.Set(k, imp)
|
||||
}
|
||||
}
|
||||
|
||||
providerMap := new(typeutil.Map) // to *Provider or *Value
|
||||
setMap := new(typeutil.Map) // to *ProviderSet, for error messages
|
||||
var bindings []binding
|
||||
visited := make(map[*ProviderSet]struct{})
|
||||
next := []*ProviderSet{set}
|
||||
for len(next) > 0 {
|
||||
curr := next[0]
|
||||
copy(next, next[1:])
|
||||
next = next[:len(next)-1]
|
||||
if _, skip := visited[curr]; skip {
|
||||
continue
|
||||
}
|
||||
visited[curr] = struct{}{}
|
||||
for _, p := range curr.Providers {
|
||||
if providerMap.At(p.Out) != nil {
|
||||
return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))
|
||||
}
|
||||
providerMap.Set(p.Out, p)
|
||||
setMap.Set(p.Out, curr)
|
||||
}
|
||||
for _, v := range curr.Values {
|
||||
if providerMap.At(v.Out) != nil {
|
||||
return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))
|
||||
}
|
||||
providerMap.Set(v.Out, v)
|
||||
setMap.Set(v.Out, curr)
|
||||
}
|
||||
for _, b := range curr.Bindings {
|
||||
bindings = append(bindings, binding{
|
||||
IfaceBinding: b,
|
||||
set: curr,
|
||||
})
|
||||
}
|
||||
for _, imp := range curr.Imports {
|
||||
next = append(next, imp)
|
||||
// Process non-binding providers in new set.
|
||||
for _, p := range set.Providers {
|
||||
if providerMap.At(p.Out) != nil {
|
||||
return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))
|
||||
}
|
||||
providerMap.Set(p.Out, p)
|
||||
setMap.Set(p.Out, set)
|
||||
}
|
||||
// Validate that bindings have their concrete type provided in the set.
|
||||
// TODO(light): Move this validation up into provider set creation.
|
||||
for _, b := range bindings {
|
||||
for _, v := range set.Values {
|
||||
if providerMap.At(v.Out) != nil {
|
||||
return nil, bindingConflictError(fset, v.Pos, v.Out, setMap.At(v.Out).(*ProviderSet))
|
||||
}
|
||||
providerMap.Set(v.Out, v)
|
||||
setMap.Set(v.Out, set)
|
||||
}
|
||||
|
||||
// Process bindings in set. Must happen after the other providers to
|
||||
// ensure the concrete type is being provided.
|
||||
for _, b := range set.Bindings {
|
||||
if providerMap.At(b.Iface) != nil {
|
||||
return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))
|
||||
}
|
||||
@@ -245,7 +243,7 @@ func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, err
|
||||
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
|
||||
}
|
||||
providerMap.Set(b.Iface, concrete)
|
||||
setMap.Set(b.Iface, b.set)
|
||||
setMap.Set(b.Iface, set)
|
||||
}
|
||||
return providerMap, nil
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/go/loader"
|
||||
"golang.org/x/tools/go/types/typeutil"
|
||||
)
|
||||
|
||||
// A ProviderSet describes a set of providers. The zero value is an empty
|
||||
@@ -44,6 +45,31 @@ type ProviderSet struct {
|
||||
Bindings []*IfaceBinding
|
||||
Values []*Value
|
||||
Imports []*ProviderSet
|
||||
|
||||
// providerMap maps from provided type to a *Provider or *Value.
|
||||
// It includes all of the imported types.
|
||||
providerMap *typeutil.Map
|
||||
}
|
||||
|
||||
// Outputs returns a new slice containing the set of possible types the
|
||||
// provider set can produce. The order is unspecified.
|
||||
func (set *ProviderSet) Outputs() []types.Type {
|
||||
return set.providerMap.Keys()
|
||||
}
|
||||
|
||||
// For returns the provider or value for the given type, or the zero
|
||||
// ProviderOrValue.
|
||||
func (set *ProviderSet) For(t types.Type) ProviderOrValue {
|
||||
switch x := set.providerMap.At(t).(type) {
|
||||
case nil:
|
||||
return ProviderOrValue{}
|
||||
case *Provider:
|
||||
return ProviderOrValue{p: x}
|
||||
case *Value:
|
||||
return ProviderOrValue{v: x}
|
||||
default:
|
||||
panic("invalid value in typeMap")
|
||||
}
|
||||
}
|
||||
|
||||
// An IfaceBinding declares that a type should be used to satisfy inputs
|
||||
@@ -181,6 +207,7 @@ func (id ProviderSetID) String() string {
|
||||
type objectCache struct {
|
||||
prog *loader.Program
|
||||
objects map[objRef]interface{} // *Provider, *ProviderSet, *IfaceBinding, or *Value
|
||||
hasher typeutil.Hasher
|
||||
}
|
||||
|
||||
type objRef struct {
|
||||
@@ -192,6 +219,7 @@ func newObjectCache(prog *loader.Program) *objectCache {
|
||||
return &objectCache{
|
||||
prog: prog,
|
||||
objects: make(map[objRef]interface{}),
|
||||
hasher: typeutil.MakeHasher(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,6 +370,11 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
|
||||
panic("unknown item type")
|
||||
}
|
||||
}
|
||||
var err error
|
||||
pset.providerMap, err = buildProviderMap(oc.prog.Fset, oc.hasher, pset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pset, nil
|
||||
}
|
||||
|
||||
@@ -618,3 +651,43 @@ func isWireImport(path string) bool {
|
||||
}
|
||||
return path == "github.com/google/go-cloud/wire"
|
||||
}
|
||||
|
||||
// ProviderOrValue is a pointer to a Provider or a Value. The zero value is
|
||||
// a nil pointer.
|
||||
type ProviderOrValue struct {
|
||||
p *Provider
|
||||
v *Value
|
||||
}
|
||||
|
||||
// IsNil reports whether pv is the zero value.
|
||||
func (pv ProviderOrValue) IsNil() bool {
|
||||
return pv.p == nil && pv.v == nil
|
||||
}
|
||||
|
||||
// IsProvider reports whether pv points to a Provider.
|
||||
func (pv ProviderOrValue) IsProvider() bool {
|
||||
return pv.p != nil
|
||||
}
|
||||
|
||||
// IsValue reports whether pv points to a Value.
|
||||
func (pv ProviderOrValue) IsValue() bool {
|
||||
return pv.v != nil
|
||||
}
|
||||
|
||||
// Provider returns pv as a Provider pointer. It panics if pv points to a
|
||||
// Value.
|
||||
func (pv ProviderOrValue) Provider() *Provider {
|
||||
if pv.v != nil {
|
||||
panic("Value pointer converted to a Provider")
|
||||
}
|
||||
return pv.p
|
||||
}
|
||||
|
||||
// Provider returns pv as a Value pointer. It panics if pv points to a
|
||||
// Provider.
|
||||
func (pv ProviderOrValue) Value() *Value {
|
||||
if pv.p != nil {
|
||||
panic("Provider pointer converted to a Value")
|
||||
}
|
||||
return pv.v
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user