wire: Build now returns an error if it has any unused arguments (google/go-cloud#268)

Fixes google/go-cloud#164
This commit is contained in:
Robert van Gent
2018-08-02 09:31:50 -07:00
committed by Ross Light
parent b348a78000
commit 85deb53791
6 changed files with 204 additions and 9 deletions

View File

@@ -108,7 +108,6 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
index.Set(g, i) index.Set(g, i)
} }
} }
if len(ec.errors) > 0 { if len(ec.errors) > 0 {
return nil, ec.errors return nil, ec.errors
} }
@@ -118,6 +117,7 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
// guaranteed to be acyclic. An index value of errAbort indicates that // guaranteed to be acyclic. An index value of errAbort indicates that
// the type was visited, but failed due to an error added to ec. // the type was visited, but failed due to an error added to ec.
errAbort := errors.New("failed to visit") errAbort := errors.New("failed to visit")
var used []*providerSetSrc
var calls []call var calls []call
type frame struct { type frame struct {
t types.Type t types.Type
@@ -145,6 +145,8 @@ dfs:
continue continue
case pv.IsProvider(): case pv.IsProvider():
p := pv.Provider() p := pv.Provider()
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
if !types.Identical(p.Out, curr.t) { if !types.Identical(p.Out, curr.t) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
i := index.At(p.Out) i := index.At(p.Out)
@@ -212,6 +214,8 @@ dfs:
index.Set(curr.t, i) index.Set(curr.t, i)
continue continue
} }
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
index.Set(curr.t, len(given)+len(calls)) index.Set(curr.t, len(given)+len(calls))
calls = append(calls, call{ calls = append(calls, call{
kind: valueExpr, kind: valueExpr,
@@ -226,14 +230,77 @@ dfs:
if len(ec.errors) > 0 { if len(ec.errors) > 0 {
return nil, ec.errors return nil, ec.errors
} }
if errs := verifyArgsUsed(set, used); len(errs) > 0 {
return nil, errs
}
return calls, nil return calls, nil
} }
// buildProviderMap creates the providerMap field for a given provider set. // verifyArgsUsed ensures that all of the arguments in set were used during solve.
// The given provider set's providerMap field is ignored. func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error {
func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, []error) { var errs []error
for _, imp := range set.Imports {
found := false
for _, u := range used {
if u.Import == imp {
found = true
break
}
}
if !found {
if imp.Name == "" {
errs = append(errs, errors.New("unused provider set"))
} else {
errs = append(errs, fmt.Errorf("unused provider set %q", imp.Name))
}
}
}
for _, p := range set.Providers {
found := false
for _, u := range used {
if u.Provider == p {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Errorf("unused provider %q", p.Name))
}
}
for _, v := range set.Values {
found := false
for _, u := range used {
if u.Value == v {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Errorf("unused value of type %s", types.TypeString(v.Out, nil)))
}
}
for _, b := range set.Bindings {
found := false
for _, u := range used {
if u.Binding == b {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Errorf("unused interface binding to type %s", types.TypeString(b.Iface, nil)))
}
}
return errs
}
// buildProviderMap creates the providerMap and srcMap fields for a given provider set.
// The given provider set's providerMap and srcMap fields are ignored.
func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, *typeutil.Map, []error) {
providerMap := new(typeutil.Map) providerMap := new(typeutil.Map)
providerMap.SetHasher(hasher) providerMap.SetHasher(hasher)
srcMap := new(typeutil.Map)
srcMap.SetHasher(hasher)
setMap := new(typeutil.Map) // to *ProviderSet, for error messages setMap := new(typeutil.Map) // to *ProviderSet, for error messages
setMap.SetHasher(hasher) setMap.SetHasher(hasher)
@@ -246,11 +313,12 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
return return
} }
providerMap.Set(k, v) providerMap.Set(k, v)
srcMap.Set(k, &providerSetSrc{Import: imp})
setMap.Set(k, imp) setMap.Set(k, imp)
}) })
} }
if len(ec.errors) > 0 { if len(ec.errors) > 0 {
return nil, ec.errors return nil, nil, ec.errors
} }
// Process non-binding providers in new set. // Process non-binding providers in new set.
@@ -260,6 +328,7 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
continue continue
} }
providerMap.Set(p.Out, p) providerMap.Set(p.Out, p)
srcMap.Set(p.Out, &providerSetSrc{Provider: p})
setMap.Set(p.Out, set) setMap.Set(p.Out, set)
} }
for _, v := range set.Values { for _, v := range set.Values {
@@ -268,10 +337,11 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
continue continue
} }
providerMap.Set(v.Out, v) providerMap.Set(v.Out, v)
srcMap.Set(v.Out, &providerSetSrc{Value: v})
setMap.Set(v.Out, set) setMap.Set(v.Out, set)
} }
if len(ec.errors) > 0 { if len(ec.errors) > 0 {
return nil, ec.errors return nil, nil, ec.errors
} }
// Process bindings in set. Must happen after the other providers to // Process bindings in set. Must happen after the other providers to
@@ -289,12 +359,13 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
continue continue
} }
providerMap.Set(b.Iface, concrete) providerMap.Set(b.Iface, concrete)
srcMap.Set(b.Iface, &providerSetSrc{Binding: b})
setMap.Set(b.Iface, set) setMap.Set(b.Iface, set)
} }
if len(ec.errors) > 0 { if len(ec.errors) > 0 {
return nil, ec.errors return nil, nil, ec.errors
} }
return providerMap, nil return providerMap, srcMap, nil
} }
func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {

View File

@@ -29,6 +29,15 @@ import (
"golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/go/types/typeutil"
) )
// A providerSetSrc captures the source for a type provided by a ProviderSet.
// Exactly one of the fields will be set.
type providerSetSrc struct {
Provider *Provider
Binding *IfaceBinding
Value *Value
Import *ProviderSet
}
// A ProviderSet describes a set of providers. The zero value is an empty // A ProviderSet describes a set of providers. The zero value is an empty
// ProviderSet. // ProviderSet.
type ProviderSet struct { type ProviderSet struct {
@@ -49,6 +58,10 @@ type ProviderSet struct {
// providerMap maps from provided type to a *Provider or *Value. // providerMap maps from provided type to a *Provider or *Value.
// It includes all of the imported types. // It includes all of the imported types.
providerMap *typeutil.Map providerMap *typeutil.Map
// srcMap maps from provided type to a *providerSetSrc capturing the
// Provider, Binding, Value, or Import that provided the type.
srcMap *typeutil.Map
} }
// Outputs returns a new slice containing the set of possible types the // Outputs returns a new slice containing the set of possible types the
@@ -496,7 +509,7 @@ func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr
return nil, ec.errors return nil, ec.errors
} }
var errs []error var errs []error
pset.providerMap, errs = buildProviderMap(oc.prog.Fset, oc.hasher, pset) pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.prog.Fset, oc.hasher, pset)
if len(errs) > 0 { if len(errs) > 0 {
return nil, errs return nil, errs
} }

View File

@@ -0,0 +1,71 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"github.com/google/go-cloud/wire"
)
func main() {
fmt.Println(injectBar())
}
type Foo int
type Bar int
type Unused int
type UnusedInSet int
type OneOfTwo int
type TwoOfTwo int
var (
unusedSet = wire.NewSet(provideUnusedInSet)
partiallyUsedSet = wire.NewSet(provideOneOfTwo, provideTwoOfTwo)
)
type Fooer interface {
Foo() string
}
func (f *Foo) Foo() string {
return fmt.Sprintf("Hello World %d", f)
}
func provideFoo() *Foo {
f := new(Foo)
*f = 1
return f
}
func provideBar(foo *Foo, one OneOfTwo) Bar {
return Bar(int(*foo) + int(one))
}
func provideUnused() Unused {
return 1
}
func provideUnusedInSet() UnusedInSet {
return 1
}
func provideOneOfTwo() OneOfTwo {
return 1
}
func provideTwoOfTwo() TwoOfTwo {
return 1
}

View File

@@ -0,0 +1,34 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//+build wireinject
package main
import (
"github.com/google/go-cloud/wire"
)
func injectBar() Bar {
wire.Build(
provideFoo, // needed as input for provideBar
provideBar, // needed for Bar
partiallyUsedSet, // 1/2 providers in the set are needed
provideUnused, // not needed -> error
wire.Value("unused"), // not needed -> error
unusedSet, // nothing in set is needed -> error
wire.Bind((*Fooer)(nil), (*Foo)(nil)), // binding to Fooer is not needed -> error
)
return 0
}

View File

@@ -0,0 +1,5 @@
ERROR
unused provider set
unused provider "provideUnused"
unused value of type string
unused interface binding to type foo.Fooer

View File

@@ -0,0 +1 @@
foo