From cb1853b1af897481e395b8fb16dfaeb2d13fdd66 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Fri, 30 Mar 2018 11:17:35 -0700 Subject: [PATCH] goose: use readable variable names Names are inferred from types most of the time, but have a fallback for a non-named type. Names are now also disambiguated from symbols in the same scope. Reviewed-by: Tuo Shan Reviewed-by: Herbie Ong --- README.md | 3 - internal/goose/goose.go | 184 ++++++++++++++---- internal/goose/goose_test.go | 79 ++++++++ .../goose/testdata/NamingWorstCase/foo/foo.go | 24 +++ .../testdata/NamingWorstCase/foo/goose.go | 11 ++ .../goose/testdata/NamingWorstCase/out.txt | 1 + internal/goose/testdata/NamingWorstCase/pkg | 1 + .../testdata/NoInjectParamNames/foo/foo.go | 24 +++ .../testdata/NoInjectParamNames/foo/goose.go | 14 ++ .../goose/testdata/NoInjectParamNames/out.txt | 1 + .../goose/testdata/NoInjectParamNames/pkg | 1 + 11 files changed, 305 insertions(+), 38 deletions(-) create mode 100644 internal/goose/testdata/NamingWorstCase/foo/foo.go create mode 100644 internal/goose/testdata/NamingWorstCase/foo/goose.go create mode 100644 internal/goose/testdata/NamingWorstCase/out.txt create mode 100644 internal/goose/testdata/NamingWorstCase/pkg create mode 100644 internal/goose/testdata/NoInjectParamNames/foo/foo.go create mode 100644 internal/goose/testdata/NoInjectParamNames/foo/goose.go create mode 100644 internal/goose/testdata/NoInjectParamNames/out.txt create mode 100644 internal/goose/testdata/NoInjectParamNames/pkg diff --git a/README.md b/README.md index a7130c5..66f7d54 100644 --- a/README.md +++ b/README.md @@ -208,9 +208,6 @@ type MySQLConnectionString string ## Future Work -- The names of imports and provider results in the generated code are not - actually as nice as shown above. I'd like to make them nicer in the - common cases while ensuring uniqueness. - Support for map bindings. - Support for multiple provider outputs. - Currently, all dependency satisfaction is done using identity. I'd like to diff --git a/internal/goose/goose.go b/internal/goose/goose.go index f94bfa6..b524275 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -14,6 +14,8 @@ import ( "sort" "strconv" "strings" + "unicode" + "unicode/utf8" "golang.org/x/tools/go/loader" "golang.org/x/tools/go/types/typeutil" @@ -43,13 +45,14 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages())) } pkgInfo := prog.InitialPackages()[0] - g := newGen(pkgInfo.Pkg.Path()) + g := newGen(prog, pkgInfo.Pkg.Path()) mc := newProviderSetCache(prog) var directives []directive for _, f := range pkgInfo.Files { if !isInjectFile(f) { continue } + // TODO(light): use same directive extraction logic as provider set finding. fileScope := pkgInfo.Scopes[f] cmap := ast.NewCommentMap(prog.Fset, f, f.Comments) for _, decl := range f.Decls { @@ -78,7 +81,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { } } } - goSrc := g.frame(pkgInfo.Pkg.Name()) + goSrc := g.frame() fmtSrc, err := format.Source(goSrc) if err != nil { // This is likely a bug from a poorly generated source file. @@ -93,24 +96,25 @@ type gen struct { currPackage string buf bytes.Buffer imports map[string]string - n int + prog *loader.Program // for determining package names } -func newGen(pkg string) *gen { +func newGen(prog *loader.Program, pkg string) *gen { return &gen{ currPackage: pkg, imports: make(map[string]string), + prog: prog, } } // frame bakes the built up source body into an unformatted Go source file. -func (g *gen) frame(pkgName string) []byte { +func (g *gen) frame() []byte { if g.buf.Len() == 0 { return nil } var buf bytes.Buffer buf.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ") - buf.WriteString(pkgName) + buf.WriteString(g.prog.Package(g.currPackage).Pkg.Name()) buf.WriteString("\n\n") if len(g.imports) > 0 { buf.WriteString("import (\n") @@ -120,6 +124,8 @@ func (g *gen) frame(pkgName string) []byte { } sort.Strings(imps) for _, path := range imps { + // TODO(light): Omit the local package identifier if it matches + // the package name. fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path) } buf.WriteString(")\n\n") @@ -160,25 +166,71 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil)) } } + + // Prequalify all types. Since import disambiguation ignores local + // variables, it takes precedence. + paramTypes := make([]string, params.Len()) + for i := 0; i < params.Len(); i++ { + paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg) + } + for _, c := range calls { + g.qualifyImport(c.importPath) + } + outTypeString := types.TypeString(outType, g.qualifyPkg) + zv := zeroValue(outType, g.qualifyPkg) + // Set up local variables + paramNames := make([]string, params.Len()) + localNames := make([]string, len(calls)) + errVar := disambiguate("err", g.nameInFileScope) + collides := func(v string) bool { + if v == errVar { + return true + } + for _, a := range paramNames { + if a == v { + return true + } + } + for _, l := range localNames { + if l == v { + return true + } + } + return g.nameInFileScope(v) + } + g.p("func %s(", name) for i := 0; i < params.Len(); i++ { if i > 0 { g.p(", ") } pi := params.At(i) - g.p("%s %s", pi.Name(), types.TypeString(pi.Type(), g.qualifyPkg)) + a := pi.Name() + if a == "" || a == "_" { + a = typeVariableName(pi.Type()) + if a == "" { + a = "arg" + } + } + paramNames[i] = disambiguate(a, collides) + g.p("%s %s", paramNames[i], paramTypes[i]) } if returnsErr { - g.p(") (%s, error) {\n", types.TypeString(outType, g.qualifyPkg)) + g.p(") (%s, error) {\n", outTypeString) } else { - g.p(") %s {\n", types.TypeString(outType, g.qualifyPkg)) + g.p(") %s {\n", outTypeString) } - zv := zeroValue(outType, g.qualifyPkg) for i := range calls { c := &calls[i] - g.p("\tv%d", i) + lname := typeVariableName(c.out) + if lname == "" { + lname = "v" + } + lname = disambiguate(lname, collides) + localNames[i] = lname + g.p("\t%s", lname) if c.hasErr { - g.p(", err") + g.p(", %s", errVar) } g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName)) for j, a := range c.args { @@ -186,14 +238,14 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se g.p(", ") } if a < params.Len() { - g.p("%s", params.At(a).Name()) + g.p("%s", paramNames[a]) } else { - g.p("v%d", a-params.Len()) + g.p("%s", localNames[a-params.Len()]) } } g.p(")\n") if c.hasErr { - g.p("\tif err != nil {\n") + g.p("\tif %s != nil {\n", errVar) // TODO(light): give information about failing provider g.p("\t\treturn %s, err\n", zv) g.p("\t}\n") @@ -202,12 +254,12 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se if len(calls) == 0 { for i := range given { if types.Identical(outType, given[i]) { - g.p("\treturn %s", params.At(i).Name()) + g.p("\treturn %s", paramNames[i]) break } } } else { - g.p("\treturn v%d", len(calls)-1) + g.p("\treturn %s", localNames[len(calls)-1]) } if returnsErr { g.p(", nil") @@ -231,12 +283,25 @@ func (g *gen) qualifyImport(path string) string { if name := g.imports[path]; name != "" { return name } - name := fmt.Sprintf("pkg%d", g.n) - g.n++ + // TODO(light): use parts of import path to disambiguate. + name := disambiguate(g.prog.Package(path).Pkg.Name(), func(n string) bool { + // Don't let an import take the "err" name. That's annoying. + return n == "err" || g.nameInFileScope(n) + }) g.imports[path] = name return name } +func (g *gen) nameInFileScope(name string) bool { + for _, other := range g.imports { + if other == name { + return true + } + } + _, obj := g.prog.Package(g.currPackage).Pkg.Scope().LookupParent(name, 0) + return obj != nil +} + func (g *gen) qualifyPkg(pkg *types.Package) string { return g.qualifyImport(pkg.Path()) } @@ -418,12 +483,8 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) { if mc.sets == nil { mc.sets = make(map[string]map[string]*providerSet) } - pkg, info, files, err := mc.getpkg(ref.importPath) - if err != nil { - mc.sets[ref.importPath] = nil - return nil, fmt.Errorf("analyze package: %v", err) - } - mods, err := findProviderSets(mc.fset, pkg, info, files) + pkg := mc.prog.Package(ref.importPath) + mods, err := findProviderSets(mc.fset, pkg.Pkg, &pkg.Info, pkg.Files) if err != nil { mc.sets[ref.importPath] = nil return nil, err @@ -436,16 +497,6 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) { return mod, nil } -func (mc *providerSetCache) getpkg(path string) (*types.Package, *types.Info, []*ast.File, error) { - // TODO(light): allow other implementations for testing - - pkg := mc.prog.Package(path) - if pkg == nil { - return nil, nil, nil, fmt.Errorf("package %q not found", path) - } - return pkg.Pkg, &pkg.Info, pkg.Files, nil -} - // solve finds the sequence of calls required to produce an output type // with an optional set of provided inputs. func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) { @@ -708,4 +759,67 @@ func zeroValue(t types.Type, qf types.Qualifier) string { } } +// typeVariableName invents a variable name derived from the type name +// or returns the empty string if one could not be found. +func typeVariableName(t types.Type) string { + if p, ok := t.(*types.Pointer); ok { + t = p.Elem() + } + tn, ok := t.(*types.Named) + if !ok { + return "" + } + // TODO(light): include package name when appropriate + return unexport(tn.Obj().Name()) +} + +// unexport converts a name that is potentially exported to an unexported name. +func unexport(name string) string { + r, sz := utf8.DecodeRuneInString(name) + if !unicode.IsUpper(r) { + // foo -> foo + return name + } + r2, sz2 := utf8.DecodeRuneInString(name[sz:]) + if !unicode.IsUpper(r2) { + // Foo -> foo + return string(unicode.ToLower(r)) + name[sz:] + } + // UPPERWord -> upperWord + sbuf := new(strings.Builder) + sbuf.WriteRune(unicode.ToLower(r)) + i := sz + r, sz = r2, sz2 + for unicode.IsUpper(r) && sz > 0 { + r2, sz2 := utf8.DecodeRuneInString(name[i+sz:]) + if sz2 > 0 && unicode.IsLower(r2) { + break + } + i += sz + sbuf.WriteRune(unicode.ToLower(r)) + r, sz = r2, sz2 + } + sbuf.WriteString(name[i:]) + return sbuf.String() +} + +// disambiguate picks a unique name, preferring name if it is already unique. +func disambiguate(name string, collides func(string) bool) string { + if !collides(name) { + return name + } + buf := []byte(name) + if len(buf) > 0 && buf[len(buf)-1] >= '0' && buf[len(buf)-1] <= '9' { + buf = append(buf, '_') + } + base := len(buf) + for n := 2; ; n++ { + buf = strconv.AppendInt(buf[:base], int64(n), 10) + sbuf := string(buf) + if !collides(sbuf) { + return sbuf + } + } +} + var errorType = types.Universe.Lookup("error").Type() diff --git a/internal/goose/goose_test.go b/internal/goose/goose_test.go index 23e3468..b1ef8b8 100644 --- a/internal/goose/goose_test.go +++ b/internal/goose/goose_test.go @@ -15,6 +15,8 @@ import ( "strings" "testing" "time" + "unicode" + "unicode/utf8" ) func TestGoose(t *testing.T) { @@ -130,6 +132,83 @@ func TestGoose(t *testing.T) { }) } +func TestUnexport(t *testing.T) { + tests := []struct { + name string + want string + }{ + {"a", "a"}, + {"ab", "ab"}, + {"A", "a"}, + {"AB", "ab"}, + {"A_", "a_"}, + {"ABc", "aBc"}, + {"ABC", "abc"}, + {"AB_", "ab_"}, + {"foo", "foo"}, + {"Foo", "foo"}, + {"HTTPClient", "httpClient"}, + {"IFace", "iFace"}, + {"SNAKE_CASE", "snake_CASE"}, + {"HTTP", "http"}, + } + for _, test := range tests { + if got := unexport(test.name); got != test.want { + t.Errorf("unexport(%q) = %q; want %q", test.name, got, test.want) + } + } +} + +func TestDisambiguate(t *testing.T) { + tests := []struct { + name string + collides map[string]bool + }{ + {"foo", nil}, + {"foo", map[string]bool{"foo": true}}, + {"foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, + {"foo1", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, + {"foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true}}, + {"foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}}, + } + for _, test := range tests { + got := disambiguate(test.name, func(name string) bool { return test.collides[name] }) + if !isIdent(got) { + t.Errorf("disambiguate(%q, %v) = %q; not an identifier", test.name, test.collides, got) + } + if test.collides[got] { + t.Errorf("disambiguate(%q, %v) = %q; ", test.name, test.collides, got) + } + } +} + +func isIdent(s string) bool { + if len(s) == 0 { + if s == "foo" { + panic("BREAK3") + } + return false + } + r, i := utf8.DecodeRuneInString(s) + if !unicode.IsLetter(r) && r != '_' { + if s == "foo" { + panic("BREAK2") + } + return false + } + for i < len(s) { + r, sz := utf8.DecodeRuneInString(s[i:]) + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + if s == "foo" { + panic("BREAK1") + } + return false + } + i += sz + } + return true +} + type testCase struct { name string pkg string diff --git a/internal/goose/testdata/NamingWorstCase/foo/foo.go b/internal/goose/testdata/NamingWorstCase/foo/foo.go new file mode 100644 index 0000000..4558109 --- /dev/null +++ b/internal/goose/testdata/NamingWorstCase/foo/foo.go @@ -0,0 +1,24 @@ +package main + +import ( + stdcontext "context" + "fmt" + "os" +) + +type context struct{} + +func main() { + c, err := inject(stdcontext.Background(), struct{}{}) + if err != nil { + fmt.Println("ERROR:", err) + os.Exit(1) + } + fmt.Println(c) +} + +//goose:provide + +func provide(ctx stdcontext.Context) (context, error) { + return context{}, nil +} diff --git a/internal/goose/testdata/NamingWorstCase/foo/goose.go b/internal/goose/testdata/NamingWorstCase/foo/goose.go new file mode 100644 index 0000000..6c276e3 --- /dev/null +++ b/internal/goose/testdata/NamingWorstCase/foo/goose.go @@ -0,0 +1,11 @@ +//+build gooseinject + +package main + +import ( + stdcontext "context" +) + +//goose:use provide + +func inject(context stdcontext.Context, err struct{}) (context, error) diff --git a/internal/goose/testdata/NamingWorstCase/out.txt b/internal/goose/testdata/NamingWorstCase/out.txt new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/internal/goose/testdata/NamingWorstCase/out.txt @@ -0,0 +1 @@ +{} diff --git a/internal/goose/testdata/NamingWorstCase/pkg b/internal/goose/testdata/NamingWorstCase/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/NamingWorstCase/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/NoInjectParamNames/foo/foo.go b/internal/goose/testdata/NoInjectParamNames/foo/foo.go new file mode 100644 index 0000000..4558109 --- /dev/null +++ b/internal/goose/testdata/NoInjectParamNames/foo/foo.go @@ -0,0 +1,24 @@ +package main + +import ( + stdcontext "context" + "fmt" + "os" +) + +type context struct{} + +func main() { + c, err := inject(stdcontext.Background(), struct{}{}) + if err != nil { + fmt.Println("ERROR:", err) + os.Exit(1) + } + fmt.Println(c) +} + +//goose:provide + +func provide(ctx stdcontext.Context) (context, error) { + return context{}, nil +} diff --git a/internal/goose/testdata/NoInjectParamNames/foo/goose.go b/internal/goose/testdata/NoInjectParamNames/foo/goose.go new file mode 100644 index 0000000..e94e92f --- /dev/null +++ b/internal/goose/testdata/NoInjectParamNames/foo/goose.go @@ -0,0 +1,14 @@ +//+build gooseinject + +package main + +import ( + stdcontext "context" +) + +// The notable characteristic of this test is that there are no +// parameter names on the inject stub. + +//goose:use provide + +func inject(stdcontext.Context, struct{}) (context, error) diff --git a/internal/goose/testdata/NoInjectParamNames/out.txt b/internal/goose/testdata/NoInjectParamNames/out.txt new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/internal/goose/testdata/NoInjectParamNames/out.txt @@ -0,0 +1 @@ +{} diff --git a/internal/goose/testdata/NoInjectParamNames/pkg b/internal/goose/testdata/NoInjectParamNames/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/NoInjectParamNames/pkg @@ -0,0 +1 @@ +foo