wire: use package names to disambiguate variables (google/go-cloud#386)

This commit is contained in:
Robert van Gent
2018-09-10 13:50:38 -07:00
committed by Ross Light
parent fab79bd5bd
commit c999a4d1b5
7 changed files with 128 additions and 50 deletions

View File

@@ -14,14 +14,14 @@ import (
// Injectors from wire.go:
func newMainService(config *foo.Config, config2 *bar.Config, config3 *baz.Config) *MainService {
func newMainService(config *foo.Config, barConfig *bar.Config, bazConfig *baz.Config) *MainService {
service := foo.New(config)
service2 := bar.New(config2, service)
service3 := baz.New(config3, service2)
barService := bar.New(barConfig, service)
bazService := baz.New(bazConfig, barService)
mainService := &MainService{
Foo: service,
Bar: service2,
Baz: service3,
Bar: barService,
Baz: bazService,
}
return mainService
}

View File

@@ -12,9 +12,9 @@ import (
// Injectors from wire.go:
func inject(context3 context2.Context, err2 struct{}) (context, error) {
context4, err := provide(context3)
mainContext, err := provide(context3)
if err != nil {
return context{}, err
}
return context4, nil
return mainContext, nil
}

View File

@@ -15,11 +15,11 @@ import (
// Injectors from foo.go:
func inject(context3 context2.Context, err2 struct{}) (context, error) {
context4, err := Provide(context3)
mainContext, err := Provide(context3)
if err != nil {
return context{}, err
}
return context4, nil
return mainContext, nil
}
// foo.go:

View File

@@ -11,10 +11,10 @@ import (
// Injectors from wire.go:
func inject(context3 context2.Context, arg struct{}) (context, error) {
context4, err := provide(context3)
func inject(contextContext context2.Context, arg struct{}) (context, error) {
mainContext, err := provide(contextContext)
if err != nil {
return context{}, err
}
return context4, nil
return mainContext, nil
}

View File

@@ -13,7 +13,7 @@ import (
func injectFooBar() FooBar {
foo := provideFoo()
bar2 := bar.ProvideBar()
fooBar := provideFooBar(foo, bar2)
barBar := bar.ProvideBar()
fooBar := provideFooBar(foo, barBar)
return fooBar
}

View File

@@ -234,7 +234,8 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov
}
if g.values[c.valueExpr] == "" {
t := c.valueTypeInfo.TypeOf(c.valueExpr)
name := disambiguate("_wire"+export(typeVariableName(t))+"Value", g.nameInFileScope)
name := typeVariableName(t, "", func(name string) string { return "_wire" + export(name) + "Value" }, g.nameInFileScope)
g.values[c.valueExpr] = name
pendingVars = append(pendingVars, pendingVar{
name: name,
@@ -472,12 +473,11 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
pi := params.At(i)
a := pi.Name()
if a == "" || a == "_" {
a = unexport(typeVariableName(pi.Type()))
if a == "" {
a = "arg"
a = typeVariableName(pi.Type(), "arg", unexport, ig.nameInInjector)
} else {
a = disambiguate(a, ig.nameInInjector)
}
}
ig.paramNames = append(ig.paramNames, disambiguate(a, ig.nameInInjector))
ig.paramNames = append(ig.paramNames, a)
ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg))
}
outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg)
@@ -493,11 +493,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
}
for i := range calls {
c := &calls[i]
lname := unexport(typeVariableName(c.out))
if lname == "" {
lname = "v"
}
lname = disambiguate(lname, ig.nameInInjector)
lname := typeVariableName(c.out, "v", unexport, ig.nameInInjector)
ig.localNames = append(ig.localNames, lname)
switch c.kind {
case structProvider:
@@ -661,22 +657,53 @@ func zeroValue(t types.Type, qf types.Qualifier) string {
}
}
// typeVariableName invents a variable name derived from the type name
// or returns the empty string if one could not be found. There are no
// guarantees about whether the name is exported or unexported: call
// export() or unexport() to convert.
func typeVariableName(t types.Type) string {
// typeVariableName invents a disambiguated variable name derived from the type name.
// If no name can be derived from the type, defaultName is used.
// transform is used to transform the derived name(s) (including defaultName);
// commonly used functions include export and unexport.
// collides is used to see if a name is ambiguous. If any one of the derived
// names is unambiguous, it used; otherwise, the first derived name is
// disambiguated using disambiguate().
func typeVariableName(t types.Type, defaultName string, transform func(string) string, collides func(string) bool) string {
if p, ok := t.(*types.Pointer); ok {
t = p.Elem()
}
var names []string
switch t := t.(type) {
case *types.Basic:
return t.Name()
case *types.Named:
// TODO(light): Include package name when appropriate.
return t.Obj().Name()
if t.Name() != "" {
names = append(names, t.Name())
}
return ""
case *types.Named:
obj := t.Obj()
if name := obj.Name(); name != "" {
names = append(names, name)
}
// Provide an alternate name prefixed with the package name if possible.
// E.g., in case of collisions, we'll use "fooCfg" instead of "cfg2".
if pkg := obj.Pkg(); pkg != nil && pkg.Name() != "" {
names = append(names, fmt.Sprintf("%s%s", pkg.Name(), strings.Title(obj.Name())))
}
}
// If we were unable to derive a name, use defaultName.
if len(names) == 0 {
names = append(names, defaultName)
}
// Transform the name(s).
for i, name := range names {
names[i] = transform(name)
}
// See if there's an unambiguous name; if so, use it.
for _, name := range names {
if !collides(name) {
return name
}
}
// Otherwise, disambiguate the first name.
return disambiguate(names[0], collides)
}
// unexport converts a name that is potentially exported to an unexported name.

View File

@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"go/build"
"go/types"
"io"
"io/ioutil"
"os"
@@ -242,30 +243,80 @@ func TestExport(t *testing.T) {
}
}
func TestTypeVariableName(t *testing.T) {
var (
boolT = types.Typ[types.Bool]
stringT = types.Typ[types.String]
fooVarT = types.NewNamed(types.NewTypeName(0, nil, "foo", stringT), stringT, nil)
nonameVarT = types.NewNamed(types.NewTypeName(0, nil, "", stringT), stringT, nil)
barVarInFooPkgT = types.NewNamed(types.NewTypeName(0, types.NewPackage("my.example/foo", "foo"), "bar", stringT), stringT, nil)
)
tests := []struct {
description string
typ types.Type
defaultName string
transformAppend string
collides map[string]bool
want string
}{
{"basic type", boolT, "", "", map[string]bool{}, "bool"},
{"basic type with transform", boolT, "", "suffix", map[string]bool{}, "boolsuffix"},
{"basic type with collision", boolT, "", "", map[string]bool{"bool": true}, "bool2"},
{"basic type with transform and collision", boolT, "", "suffix", map[string]bool{"boolsuffix": true}, "boolsuffix2"},
{"a different basic type", stringT, "", "", map[string]bool{}, "string"},
{"named type", fooVarT, "", "", map[string]bool{}, "foo"},
{"named type with transform", fooVarT, "", "suffix", map[string]bool{}, "foosuffix"},
{"named type with collision", fooVarT, "", "", map[string]bool{"foo": true}, "foo2"},
{"named type with transform and collision", fooVarT, "", "suffix", map[string]bool{"foosuffix": true}, "foosuffix2"},
{"noname type", nonameVarT, "bar", "", map[string]bool{}, "bar"},
{"noname type with transform", nonameVarT, "bar", "s", map[string]bool{}, "bars"},
{"noname type with transform and collision", nonameVarT, "bar", "s", map[string]bool{"bars": true}, "bars2"},
{"var in pkg type", barVarInFooPkgT, "", "", map[string]bool{}, "bar"},
{"var in pkg type with collision", barVarInFooPkgT, "", "", map[string]bool{"bar": true}, "fooBar"},
{"var in pkg type with double collision", barVarInFooPkgT, "", "", map[string]bool{"bar": true, "fooBar": true}, "bar2"},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%s: typeVariableName(%v, %q, %q, %v)", test.description, test.typ, test.defaultName, test.transformAppend, test.collides), func(t *testing.T) {
got := typeVariableName(test.typ, test.defaultName, func(name string) string { return name + test.transformAppend }, func(name string) bool { return test.collides[name] })
if !isIdent(got) {
t.Errorf("%q is not an identifier", got)
}
if got != test.want {
t.Errorf("got %q want %q", got, test.want)
}
if test.collides[got] {
t.Errorf("%q collides", got)
}
})
}
}
func TestDisambiguate(t *testing.T) {
tests := []struct {
name string
contains string
want string
collides map[string]bool
}{
{"foo", "foo", nil},
{"foo", "foo", map[string]bool{"foo": true}},
{"foo", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo1", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo\u0661", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo\u0661", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}},
{"foo", "foo2", map[string]bool{"foo": true}},
{"foo", "foo3", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo1", "foo1_2", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo\u0661", "foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo\u0661", "foo\u06612", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}},
}
for _, test := range tests {
t.Run(fmt.Sprintf("disambiguate(%q, %v)", test.name, test.collides), func(t *testing.T) {
got := disambiguate(test.name, func(name string) bool { return test.collides[name] })
if !isIdent(got) {
t.Errorf("disambiguate(%q, %v) = %q; not an identifier", test.name, test.collides, got)
t.Errorf("%q is not an identifier", got)
}
if !strings.Contains(got, test.contains) {
t.Errorf("disambiguate(%q, %v) = %q; wanted to contain %q", test.name, test.collides, got, test.contains)
if got != test.want {
t.Errorf("got %q want %q", got, test.want)
}
if test.collides[got] {
t.Errorf("disambiguate(%q, %v) = %q; ", test.name, test.collides, got)
t.Errorf("%q collides", got)
}
})
}
}