From c999a4d1b565ac46aea2f6b370f324dfb48e671d Mon Sep 17 00:00:00 2001 From: Robert van Gent Date: Mon, 10 Sep 2018 13:50:38 -0700 Subject: [PATCH] wire: use package names to disambiguate variables (google/go-cloud#386) --- .../MultipleSimilarPackages/want/wire_gen.go | 10 +-- .../testdata/NamingWorstCase/want/wire_gen.go | 4 +- .../NamingWorstCaseAllInOne/want/wire_gen.go | 4 +- .../NoInjectParamNames/want/wire_gen.go | 6 +- .../wire/testdata/PkgImport/want/wire_gen.go | 4 +- internal/wire/wire.go | 67 ++++++++++----- internal/wire/wire_test.go | 83 +++++++++++++++---- 7 files changed, 128 insertions(+), 50 deletions(-) diff --git a/internal/wire/testdata/MultipleSimilarPackages/want/wire_gen.go b/internal/wire/testdata/MultipleSimilarPackages/want/wire_gen.go index 1e8cc60..d564706 100644 --- a/internal/wire/testdata/MultipleSimilarPackages/want/wire_gen.go +++ b/internal/wire/testdata/MultipleSimilarPackages/want/wire_gen.go @@ -14,14 +14,14 @@ import ( // Injectors from wire.go: -func newMainService(config *foo.Config, config2 *bar.Config, config3 *baz.Config) *MainService { +func newMainService(config *foo.Config, barConfig *bar.Config, bazConfig *baz.Config) *MainService { service := foo.New(config) - service2 := bar.New(config2, service) - service3 := baz.New(config3, service2) + barService := bar.New(barConfig, service) + bazService := baz.New(bazConfig, barService) mainService := &MainService{ Foo: service, - Bar: service2, - Baz: service3, + Bar: barService, + Baz: bazService, } return mainService } diff --git a/internal/wire/testdata/NamingWorstCase/want/wire_gen.go b/internal/wire/testdata/NamingWorstCase/want/wire_gen.go index a2801c6..f44ccb6 100644 --- a/internal/wire/testdata/NamingWorstCase/want/wire_gen.go +++ b/internal/wire/testdata/NamingWorstCase/want/wire_gen.go @@ -12,9 +12,9 @@ import ( // Injectors from wire.go: func inject(context3 context2.Context, err2 struct{}) (context, error) { - context4, err := provide(context3) + mainContext, err := provide(context3) if err != nil { return context{}, err } - return context4, nil + return mainContext, nil } diff --git a/internal/wire/testdata/NamingWorstCaseAllInOne/want/wire_gen.go b/internal/wire/testdata/NamingWorstCaseAllInOne/want/wire_gen.go index 8a21ea4..0f647cb 100644 --- a/internal/wire/testdata/NamingWorstCaseAllInOne/want/wire_gen.go +++ b/internal/wire/testdata/NamingWorstCaseAllInOne/want/wire_gen.go @@ -15,11 +15,11 @@ import ( // Injectors from foo.go: func inject(context3 context2.Context, err2 struct{}) (context, error) { - context4, err := Provide(context3) + mainContext, err := Provide(context3) if err != nil { return context{}, err } - return context4, nil + return mainContext, nil } // foo.go: diff --git a/internal/wire/testdata/NoInjectParamNames/want/wire_gen.go b/internal/wire/testdata/NoInjectParamNames/want/wire_gen.go index 0d15c50..51f4240 100644 --- a/internal/wire/testdata/NoInjectParamNames/want/wire_gen.go +++ b/internal/wire/testdata/NoInjectParamNames/want/wire_gen.go @@ -11,10 +11,10 @@ import ( // Injectors from wire.go: -func inject(context3 context2.Context, arg struct{}) (context, error) { - context4, err := provide(context3) +func inject(contextContext context2.Context, arg struct{}) (context, error) { + mainContext, err := provide(contextContext) if err != nil { return context{}, err } - return context4, nil + return mainContext, nil } diff --git a/internal/wire/testdata/PkgImport/want/wire_gen.go b/internal/wire/testdata/PkgImport/want/wire_gen.go index a4e44b0..be14076 100644 --- a/internal/wire/testdata/PkgImport/want/wire_gen.go +++ b/internal/wire/testdata/PkgImport/want/wire_gen.go @@ -13,7 +13,7 @@ import ( func injectFooBar() FooBar { foo := provideFoo() - bar2 := bar.ProvideBar() - fooBar := provideFooBar(foo, bar2) + barBar := bar.ProvideBar() + fooBar := provideFooBar(foo, barBar) return fooBar } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 666cae1..3b0d873 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -234,7 +234,8 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov } if g.values[c.valueExpr] == "" { t := c.valueTypeInfo.TypeOf(c.valueExpr) - name := disambiguate("_wire"+export(typeVariableName(t))+"Value", g.nameInFileScope) + + name := typeVariableName(t, "", func(name string) string { return "_wire" + export(name) + "Value" }, g.nameInFileScope) g.values[c.valueExpr] = name pendingVars = append(pendingVars, pendingVar{ name: name, @@ -472,12 +473,11 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal pi := params.At(i) a := pi.Name() if a == "" || a == "_" { - a = unexport(typeVariableName(pi.Type())) - if a == "" { - a = "arg" - } + a = typeVariableName(pi.Type(), "arg", unexport, ig.nameInInjector) + } else { + a = disambiguate(a, ig.nameInInjector) } - ig.paramNames = append(ig.paramNames, disambiguate(a, ig.nameInInjector)) + ig.paramNames = append(ig.paramNames, a) ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg)) } outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg) @@ -493,11 +493,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal } for i := range calls { c := &calls[i] - lname := unexport(typeVariableName(c.out)) - if lname == "" { - lname = "v" - } - lname = disambiguate(lname, ig.nameInInjector) + lname := typeVariableName(c.out, "v", unexport, ig.nameInInjector) ig.localNames = append(ig.localNames, lname) switch c.kind { case structProvider: @@ -661,22 +657,53 @@ func zeroValue(t types.Type, qf types.Qualifier) string { } } -// typeVariableName invents a variable name derived from the type name -// or returns the empty string if one could not be found. There are no -// guarantees about whether the name is exported or unexported: call -// export() or unexport() to convert. -func typeVariableName(t types.Type) string { +// typeVariableName invents a disambiguated variable name derived from the type name. +// If no name can be derived from the type, defaultName is used. +// transform is used to transform the derived name(s) (including defaultName); +// commonly used functions include export and unexport. +// collides is used to see if a name is ambiguous. If any one of the derived +// names is unambiguous, it used; otherwise, the first derived name is +// disambiguated using disambiguate(). +func typeVariableName(t types.Type, defaultName string, transform func(string) string, collides func(string) bool) string { if p, ok := t.(*types.Pointer); ok { t = p.Elem() } + var names []string switch t := t.(type) { case *types.Basic: - return t.Name() + if t.Name() != "" { + names = append(names, t.Name()) + } case *types.Named: - // TODO(light): Include package name when appropriate. - return t.Obj().Name() + obj := t.Obj() + if name := obj.Name(); name != "" { + names = append(names, name) + } + // Provide an alternate name prefixed with the package name if possible. + // E.g., in case of collisions, we'll use "fooCfg" instead of "cfg2". + if pkg := obj.Pkg(); pkg != nil && pkg.Name() != "" { + names = append(names, fmt.Sprintf("%s%s", pkg.Name(), strings.Title(obj.Name()))) + } } - return "" + + // If we were unable to derive a name, use defaultName. + if len(names) == 0 { + names = append(names, defaultName) + } + + // Transform the name(s). + for i, name := range names { + names[i] = transform(name) + } + + // See if there's an unambiguous name; if so, use it. + for _, name := range names { + if !collides(name) { + return name + } + } + // Otherwise, disambiguate the first name. + return disambiguate(names[0], collides) } // unexport converts a name that is potentially exported to an unexported name. diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 90ed012..bf8a01b 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "go/build" + "go/types" "io" "io/ioutil" "os" @@ -242,30 +243,80 @@ func TestExport(t *testing.T) { } } +func TestTypeVariableName(t *testing.T) { + var ( + boolT = types.Typ[types.Bool] + stringT = types.Typ[types.String] + fooVarT = types.NewNamed(types.NewTypeName(0, nil, "foo", stringT), stringT, nil) + nonameVarT = types.NewNamed(types.NewTypeName(0, nil, "", stringT), stringT, nil) + barVarInFooPkgT = types.NewNamed(types.NewTypeName(0, types.NewPackage("my.example/foo", "foo"), "bar", stringT), stringT, nil) + ) + tests := []struct { + description string + typ types.Type + defaultName string + transformAppend string + collides map[string]bool + want string + }{ + {"basic type", boolT, "", "", map[string]bool{}, "bool"}, + {"basic type with transform", boolT, "", "suffix", map[string]bool{}, "boolsuffix"}, + {"basic type with collision", boolT, "", "", map[string]bool{"bool": true}, "bool2"}, + {"basic type with transform and collision", boolT, "", "suffix", map[string]bool{"boolsuffix": true}, "boolsuffix2"}, + {"a different basic type", stringT, "", "", map[string]bool{}, "string"}, + {"named type", fooVarT, "", "", map[string]bool{}, "foo"}, + {"named type with transform", fooVarT, "", "suffix", map[string]bool{}, "foosuffix"}, + {"named type with collision", fooVarT, "", "", map[string]bool{"foo": true}, "foo2"}, + {"named type with transform and collision", fooVarT, "", "suffix", map[string]bool{"foosuffix": true}, "foosuffix2"}, + {"noname type", nonameVarT, "bar", "", map[string]bool{}, "bar"}, + {"noname type with transform", nonameVarT, "bar", "s", map[string]bool{}, "bars"}, + {"noname type with transform and collision", nonameVarT, "bar", "s", map[string]bool{"bars": true}, "bars2"}, + {"var in pkg type", barVarInFooPkgT, "", "", map[string]bool{}, "bar"}, + {"var in pkg type with collision", barVarInFooPkgT, "", "", map[string]bool{"bar": true}, "fooBar"}, + {"var in pkg type with double collision", barVarInFooPkgT, "", "", map[string]bool{"bar": true, "fooBar": true}, "bar2"}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%s: typeVariableName(%v, %q, %q, %v)", test.description, test.typ, test.defaultName, test.transformAppend, test.collides), func(t *testing.T) { + got := typeVariableName(test.typ, test.defaultName, func(name string) string { return name + test.transformAppend }, func(name string) bool { return test.collides[name] }) + if !isIdent(got) { + t.Errorf("%q is not an identifier", got) + } + if got != test.want { + t.Errorf("got %q want %q", got, test.want) + } + if test.collides[got] { + t.Errorf("%q collides", got) + } + }) + } +} + func TestDisambiguate(t *testing.T) { tests := []struct { name string - contains string + want string collides map[string]bool }{ {"foo", "foo", nil}, - {"foo", "foo", map[string]bool{"foo": true}}, - {"foo", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, - {"foo1", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, - {"foo\u0661", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, - {"foo\u0661", "foo", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}}, + {"foo", "foo2", map[string]bool{"foo": true}}, + {"foo", "foo3", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, + {"foo1", "foo1_2", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, + {"foo\u0661", "foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, + {"foo\u0661", "foo\u06612", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}}, } for _, test := range tests { - got := disambiguate(test.name, func(name string) bool { return test.collides[name] }) - if !isIdent(got) { - t.Errorf("disambiguate(%q, %v) = %q; not an identifier", test.name, test.collides, got) - } - if !strings.Contains(got, test.contains) { - t.Errorf("disambiguate(%q, %v) = %q; wanted to contain %q", test.name, test.collides, got, test.contains) - } - if test.collides[got] { - t.Errorf("disambiguate(%q, %v) = %q; ", test.name, test.collides, got) - } + t.Run(fmt.Sprintf("disambiguate(%q, %v)", test.name, test.collides), func(t *testing.T) { + got := disambiguate(test.name, func(name string) bool { return test.collides[name] }) + if !isIdent(got) { + t.Errorf("%q is not an identifier", got) + } + if got != test.want { + t.Errorf("got %q want %q", got, test.want) + } + if test.collides[got] { + t.Errorf("%q collides", got) + } + }) } }