wire: use package names to disambiguate variables (google/go-cloud#386)
This commit is contained in:
committed by
Ross Light
parent
fab79bd5bd
commit
c999a4d1b5
@@ -14,14 +14,14 @@ import (
|
||||
|
||||
// Injectors from wire.go:
|
||||
|
||||
func newMainService(config *foo.Config, config2 *bar.Config, config3 *baz.Config) *MainService {
|
||||
func newMainService(config *foo.Config, barConfig *bar.Config, bazConfig *baz.Config) *MainService {
|
||||
service := foo.New(config)
|
||||
service2 := bar.New(config2, service)
|
||||
service3 := baz.New(config3, service2)
|
||||
barService := bar.New(barConfig, service)
|
||||
bazService := baz.New(bazConfig, barService)
|
||||
mainService := &MainService{
|
||||
Foo: service,
|
||||
Bar: service2,
|
||||
Baz: service3,
|
||||
Bar: barService,
|
||||
Baz: bazService,
|
||||
}
|
||||
return mainService
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
// Injectors from wire.go:
|
||||
|
||||
func inject(context3 context2.Context, err2 struct{}) (context, error) {
|
||||
context4, err := provide(context3)
|
||||
mainContext, err := provide(context3)
|
||||
if err != nil {
|
||||
return context{}, err
|
||||
}
|
||||
return context4, nil
|
||||
return mainContext, nil
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
// Injectors from foo.go:
|
||||
|
||||
func inject(context3 context2.Context, err2 struct{}) (context, error) {
|
||||
context4, err := Provide(context3)
|
||||
mainContext, err := Provide(context3)
|
||||
if err != nil {
|
||||
return context{}, err
|
||||
}
|
||||
return context4, nil
|
||||
return mainContext, nil
|
||||
}
|
||||
|
||||
// foo.go:
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
|
||||
// Injectors from wire.go:
|
||||
|
||||
func inject(context3 context2.Context, arg struct{}) (context, error) {
|
||||
context4, err := provide(context3)
|
||||
func inject(contextContext context2.Context, arg struct{}) (context, error) {
|
||||
mainContext, err := provide(contextContext)
|
||||
if err != nil {
|
||||
return context{}, err
|
||||
}
|
||||
return context4, nil
|
||||
return mainContext, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func injectFooBar() FooBar {
|
||||
foo := provideFoo()
|
||||
bar2 := bar.ProvideBar()
|
||||
fooBar := provideFooBar(foo, bar2)
|
||||
barBar := bar.ProvideBar()
|
||||
fooBar := provideFooBar(foo, barBar)
|
||||
return fooBar
|
||||
}
|
||||
|
||||
@@ -234,7 +234,8 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov
|
||||
}
|
||||
if g.values[c.valueExpr] == "" {
|
||||
t := c.valueTypeInfo.TypeOf(c.valueExpr)
|
||||
name := disambiguate("_wire"+export(typeVariableName(t))+"Value", g.nameInFileScope)
|
||||
|
||||
name := typeVariableName(t, "", func(name string) string { return "_wire" + export(name) + "Value" }, g.nameInFileScope)
|
||||
g.values[c.valueExpr] = name
|
||||
pendingVars = append(pendingVars, pendingVar{
|
||||
name: name,
|
||||
@@ -472,12 +473,11 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
|
||||
pi := params.At(i)
|
||||
a := pi.Name()
|
||||
if a == "" || a == "_" {
|
||||
a = unexport(typeVariableName(pi.Type()))
|
||||
if a == "" {
|
||||
a = "arg"
|
||||
}
|
||||
a = typeVariableName(pi.Type(), "arg", unexport, ig.nameInInjector)
|
||||
} else {
|
||||
a = disambiguate(a, ig.nameInInjector)
|
||||
}
|
||||
ig.paramNames = append(ig.paramNames, disambiguate(a, ig.nameInInjector))
|
||||
ig.paramNames = append(ig.paramNames, a)
|
||||
ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg))
|
||||
}
|
||||
outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg)
|
||||
@@ -493,11 +493,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
|
||||
}
|
||||
for i := range calls {
|
||||
c := &calls[i]
|
||||
lname := unexport(typeVariableName(c.out))
|
||||
if lname == "" {
|
||||
lname = "v"
|
||||
}
|
||||
lname = disambiguate(lname, ig.nameInInjector)
|
||||
lname := typeVariableName(c.out, "v", unexport, ig.nameInInjector)
|
||||
ig.localNames = append(ig.localNames, lname)
|
||||
switch c.kind {
|
||||
case structProvider:
|
||||
@@ -661,22 +657,53 @@ func zeroValue(t types.Type, qf types.Qualifier) string {
|
||||
}
|
||||
}
|
||||
|
||||
// typeVariableName invents a variable name derived from the type name
|
||||
// or returns the empty string if one could not be found. There are no
|
||||
// guarantees about whether the name is exported or unexported: call
|
||||
// export() or unexport() to convert.
|
||||
func typeVariableName(t types.Type) string {
|
||||
// typeVariableName invents a disambiguated variable name derived from the type name.
|
||||
// If no name can be derived from the type, defaultName is used.
|
||||
// transform is used to transform the derived name(s) (including defaultName);
|
||||
// commonly used functions include export and unexport.
|
||||
// collides is used to see if a name is ambiguous. If any one of the derived
|
||||
// names is unambiguous, it used; otherwise, the first derived name is
|
||||
// disambiguated using disambiguate().
|
||||
func typeVariableName(t types.Type, defaultName string, transform func(string) string, collides func(string) bool) string {
|
||||
if p, ok := t.(*types.Pointer); ok {
|
||||
t = p.Elem()
|
||||
}
|
||||
var names []string
|
||||
switch t := t.(type) {
|
||||
case *types.Basic:
|
||||
return t.Name()
|
||||
if t.Name() != "" {
|
||||
names = append(names, t.Name())
|
||||
}
|
||||
case *types.Named:
|
||||
// TODO(light): Include package name when appropriate.
|
||||
return t.Obj().Name()
|
||||
obj := t.Obj()
|
||||
if name := obj.Name(); name != "" {
|
||||
names = append(names, name)
|
||||
}
|
||||
// Provide an alternate name prefixed with the package name if possible.
|
||||
// E.g., in case of collisions, we'll use "fooCfg" instead of "cfg2".
|
||||
if pkg := obj.Pkg(); pkg != nil && pkg.Name() != "" {
|
||||
names = append(names, fmt.Sprintf("%s%s", pkg.Name(), strings.Title(obj.Name())))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
// If we were unable to derive a name, use defaultName.
|
||||
if len(names) == 0 {
|
||||
names = append(names, defaultName)
|
||||
}
|
||||
|
||||
// Transform the name(s).
|
||||
for i, name := range names {
|
||||
names[i] = transform(name)
|
||||
}
|
||||
|
||||
// See if there's an unambiguous name; if so, use it.
|
||||
for _, name := range names {
|
||||
if !collides(name) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
// Otherwise, disambiguate the first name.
|
||||
return disambiguate(names[0], collides)
|
||||
}
|
||||
|
||||
// unexport converts a name that is potentially exported to an unexported name.
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"go/types"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@@ -242,30 +243,80 @@ func TestExport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeVariableName(t *testing.T) {
|
||||
var (
|
||||
boolT = types.Typ[types.Bool]
|
||||
stringT = types.Typ[types.String]
|
||||
fooVarT = types.NewNamed(types.NewTypeName(0, nil, "foo", stringT), stringT, nil)
|
||||
nonameVarT = types.NewNamed(types.NewTypeName(0, nil, "", stringT), stringT, nil)
|
||||
barVarInFooPkgT = types.NewNamed(types.NewTypeName(0, types.NewPackage("my.example/foo", "foo"), "bar", stringT), stringT, nil)
|
||||
)
|
||||
tests := []struct {
|
||||
description string
|
||||
typ types.Type
|
||||
defaultName string
|
||||
transformAppend string
|
||||
collides map[string]bool
|
||||
want string
|
||||
}{
|
||||
{"basic type", boolT, "", "", map[string]bool{}, "bool"},
|
||||
{"basic type with transform", boolT, "", "suffix", map[string]bool{}, "boolsuffix"},
|
||||
{"basic type with collision", boolT, "", "", map[string]bool{"bool": true}, "bool2"},
|
||||
{"basic type with transform and collision", boolT, "", "suffix", map[string]bool{"boolsuffix": true}, "boolsuffix2"},
|
||||
{"a different basic type", stringT, "", "", map[string]bool{}, "string"},
|
||||
{"named type", fooVarT, "", "", map[string]bool{}, "foo"},
|
||||
{"named type with transform", fooVarT, "", "suffix", map[string]bool{}, "foosuffix"},
|
||||
{"named type with collision", fooVarT, "", "", map[string]bool{"foo": true}, "foo2"},
|
||||
{"named type with transform and collision", fooVarT, "", "suffix", map[string]bool{"foosuffix": true}, "foosuffix2"},
|
||||
{"noname type", nonameVarT, "bar", "", map[string]bool{}, "bar"},
|
||||
{"noname type with transform", nonameVarT, "bar", "s", map[string]bool{}, "bars"},
|
||||
{"noname type with transform and collision", nonameVarT, "bar", "s", map[string]bool{"bars": true}, "bars2"},
|
||||
{"var in pkg type", barVarInFooPkgT, "", "", map[string]bool{}, "bar"},
|
||||
{"var in pkg type with collision", barVarInFooPkgT, "", "", map[string]bool{"bar": true}, "fooBar"},
|
||||
{"var in pkg type with double collision", barVarInFooPkgT, "", "", map[string]bool{"bar": true, "fooBar": true}, "bar2"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s: typeVariableName(%v, %q, %q, %v)", test.description, test.typ, test.defaultName, test.transformAppend, test.collides), func(t *testing.T) {
|
||||
got := typeVariableName(test.typ, test.defaultName, func(name string) string { return name + test.transformAppend }, func(name string) bool { return test.collides[name] })
|
||||
if !isIdent(got) {
|
||||
t.Errorf("%q is not an identifier", got)
|
||||
}
|
||||
if got != test.want {
|
||||
t.Errorf("got %q want %q", got, test.want)
|
||||
}
|
||||
if test.collides[got] {
|
||||
t.Errorf("%q collides", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisambiguate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contains string
|
||||
want string
|
||||
collides map[string]bool
|
||||
}{
|
||||
{"foo", "foo", nil},
|
||||
{"foo", "foo", map[string]bool{"foo": true}},
|
||||
{"foo", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
|
||||
{"foo1", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
|
||||
{"foo\u0661", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
|
||||
{"foo\u0661", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}},
|
||||
{"foo", "foo2", map[string]bool{"foo": true}},
|
||||
{"foo", "foo3", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
|
||||
{"foo1", "foo1_2", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
|
||||
{"foo\u0661", "foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
|
||||
{"foo\u0661", "foo\u06612", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}},
|
||||
}
|
||||
for _, test := range tests {
|
||||
got := disambiguate(test.name, func(name string) bool { return test.collides[name] })
|
||||
if !isIdent(got) {
|
||||
t.Errorf("disambiguate(%q, %v) = %q; not an identifier", test.name, test.collides, got)
|
||||
}
|
||||
if !strings.Contains(got, test.contains) {
|
||||
t.Errorf("disambiguate(%q, %v) = %q; wanted to contain %q", test.name, test.collides, got, test.contains)
|
||||
}
|
||||
if test.collides[got] {
|
||||
t.Errorf("disambiguate(%q, %v) = %q; ", test.name, test.collides, got)
|
||||
}
|
||||
t.Run(fmt.Sprintf("disambiguate(%q, %v)", test.name, test.collides), func(t *testing.T) {
|
||||
got := disambiguate(test.name, func(name string) bool { return test.collides[name] })
|
||||
if !isIdent(got) {
|
||||
t.Errorf("%q is not an identifier", got)
|
||||
}
|
||||
if got != test.want {
|
||||
t.Errorf("got %q want %q", got, test.want)
|
||||
}
|
||||
if test.collides[got] {
|
||||
t.Errorf("%q collides", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user