goose: use readable variable names

Names are inferred from types most of the time, but have a fallback for
a non-named type. Names are now also disambiguated from symbols in the
same scope.

Reviewed-by: Tuo Shan <shantuo@google.com>
Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Ross Light
2018-03-30 11:17:35 -07:00
parent 50dbe5a65d
commit cb1853b1af
11 changed files with 305 additions and 38 deletions

View File

@@ -208,9 +208,6 @@ type MySQLConnectionString string
## Future Work
- The names of imports and provider results in the generated code are not
actually as nice as shown above. I'd like to make them nicer in the
common cases while ensuring uniqueness.
- Support for map bindings.
- Support for multiple provider outputs.
- Currently, all dependency satisfaction is done using identity. I'd like to

View File

@@ -14,6 +14,8 @@ import (
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/types/typeutil"
@@ -43,13 +45,14 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))
}
pkgInfo := prog.InitialPackages()[0]
g := newGen(pkgInfo.Pkg.Path())
g := newGen(prog, pkgInfo.Pkg.Path())
mc := newProviderSetCache(prog)
var directives []directive
for _, f := range pkgInfo.Files {
if !isInjectFile(f) {
continue
}
// TODO(light): use same directive extraction logic as provider set finding.
fileScope := pkgInfo.Scopes[f]
cmap := ast.NewCommentMap(prog.Fset, f, f.Comments)
for _, decl := range f.Decls {
@@ -78,7 +81,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
}
}
}
goSrc := g.frame(pkgInfo.Pkg.Name())
goSrc := g.frame()
fmtSrc, err := format.Source(goSrc)
if err != nil {
// This is likely a bug from a poorly generated source file.
@@ -93,24 +96,25 @@ type gen struct {
currPackage string
buf bytes.Buffer
imports map[string]string
n int
prog *loader.Program // for determining package names
}
func newGen(pkg string) *gen {
func newGen(prog *loader.Program, pkg string) *gen {
return &gen{
currPackage: pkg,
imports: make(map[string]string),
prog: prog,
}
}
// frame bakes the built up source body into an unformatted Go source file.
func (g *gen) frame(pkgName string) []byte {
func (g *gen) frame() []byte {
if g.buf.Len() == 0 {
return nil
}
var buf bytes.Buffer
buf.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ")
buf.WriteString(pkgName)
buf.WriteString(g.prog.Package(g.currPackage).Pkg.Name())
buf.WriteString("\n\n")
if len(g.imports) > 0 {
buf.WriteString("import (\n")
@@ -120,6 +124,8 @@ func (g *gen) frame(pkgName string) []byte {
}
sort.Strings(imps)
for _, path := range imps {
// TODO(light): Omit the local package identifier if it matches
// the package name.
fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path)
}
buf.WriteString(")\n\n")
@@ -160,25 +166,71 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
}
}
// Prequalify all types. Since import disambiguation ignores local
// variables, it takes precedence.
paramTypes := make([]string, params.Len())
for i := 0; i < params.Len(); i++ {
paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg)
}
for _, c := range calls {
g.qualifyImport(c.importPath)
}
outTypeString := types.TypeString(outType, g.qualifyPkg)
zv := zeroValue(outType, g.qualifyPkg)
// Set up local variables
paramNames := make([]string, params.Len())
localNames := make([]string, len(calls))
errVar := disambiguate("err", g.nameInFileScope)
collides := func(v string) bool {
if v == errVar {
return true
}
for _, a := range paramNames {
if a == v {
return true
}
}
for _, l := range localNames {
if l == v {
return true
}
}
return g.nameInFileScope(v)
}
g.p("func %s(", name)
for i := 0; i < params.Len(); i++ {
if i > 0 {
g.p(", ")
}
pi := params.At(i)
g.p("%s %s", pi.Name(), types.TypeString(pi.Type(), g.qualifyPkg))
a := pi.Name()
if a == "" || a == "_" {
a = typeVariableName(pi.Type())
if a == "" {
a = "arg"
}
}
paramNames[i] = disambiguate(a, collides)
g.p("%s %s", paramNames[i], paramTypes[i])
}
if returnsErr {
g.p(") (%s, error) {\n", types.TypeString(outType, g.qualifyPkg))
g.p(") (%s, error) {\n", outTypeString)
} else {
g.p(") %s {\n", types.TypeString(outType, g.qualifyPkg))
g.p(") %s {\n", outTypeString)
}
zv := zeroValue(outType, g.qualifyPkg)
for i := range calls {
c := &calls[i]
g.p("\tv%d", i)
lname := typeVariableName(c.out)
if lname == "" {
lname = "v"
}
lname = disambiguate(lname, collides)
localNames[i] = lname
g.p("\t%s", lname)
if c.hasErr {
g.p(", err")
g.p(", %s", errVar)
}
g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName))
for j, a := range c.args {
@@ -186,14 +238,14 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
g.p(", ")
}
if a < params.Len() {
g.p("%s", params.At(a).Name())
g.p("%s", paramNames[a])
} else {
g.p("v%d", a-params.Len())
g.p("%s", localNames[a-params.Len()])
}
}
g.p(")\n")
if c.hasErr {
g.p("\tif err != nil {\n")
g.p("\tif %s != nil {\n", errVar)
// TODO(light): give information about failing provider
g.p("\t\treturn %s, err\n", zv)
g.p("\t}\n")
@@ -202,12 +254,12 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
if len(calls) == 0 {
for i := range given {
if types.Identical(outType, given[i]) {
g.p("\treturn %s", params.At(i).Name())
g.p("\treturn %s", paramNames[i])
break
}
}
} else {
g.p("\treturn v%d", len(calls)-1)
g.p("\treturn %s", localNames[len(calls)-1])
}
if returnsErr {
g.p(", nil")
@@ -231,12 +283,25 @@ func (g *gen) qualifyImport(path string) string {
if name := g.imports[path]; name != "" {
return name
}
name := fmt.Sprintf("pkg%d", g.n)
g.n++
// TODO(light): use parts of import path to disambiguate.
name := disambiguate(g.prog.Package(path).Pkg.Name(), func(n string) bool {
// Don't let an import take the "err" name. That's annoying.
return n == "err" || g.nameInFileScope(n)
})
g.imports[path] = name
return name
}
func (g *gen) nameInFileScope(name string) bool {
for _, other := range g.imports {
if other == name {
return true
}
}
_, obj := g.prog.Package(g.currPackage).Pkg.Scope().LookupParent(name, 0)
return obj != nil
}
func (g *gen) qualifyPkg(pkg *types.Package) string {
return g.qualifyImport(pkg.Path())
}
@@ -418,12 +483,8 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
if mc.sets == nil {
mc.sets = make(map[string]map[string]*providerSet)
}
pkg, info, files, err := mc.getpkg(ref.importPath)
if err != nil {
mc.sets[ref.importPath] = nil
return nil, fmt.Errorf("analyze package: %v", err)
}
mods, err := findProviderSets(mc.fset, pkg, info, files)
pkg := mc.prog.Package(ref.importPath)
mods, err := findProviderSets(mc.fset, pkg.Pkg, &pkg.Info, pkg.Files)
if err != nil {
mc.sets[ref.importPath] = nil
return nil, err
@@ -436,16 +497,6 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
return mod, nil
}
func (mc *providerSetCache) getpkg(path string) (*types.Package, *types.Info, []*ast.File, error) {
// TODO(light): allow other implementations for testing
pkg := mc.prog.Package(path)
if pkg == nil {
return nil, nil, nil, fmt.Errorf("package %q not found", path)
}
return pkg.Pkg, &pkg.Info, pkg.Files, nil
}
// solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs.
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) {
@@ -708,4 +759,67 @@ func zeroValue(t types.Type, qf types.Qualifier) string {
}
}
// typeVariableName invents a variable name derived from the type name
// or returns the empty string if one could not be found.
func typeVariableName(t types.Type) string {
if p, ok := t.(*types.Pointer); ok {
t = p.Elem()
}
tn, ok := t.(*types.Named)
if !ok {
return ""
}
// TODO(light): include package name when appropriate
return unexport(tn.Obj().Name())
}
// unexport converts a name that is potentially exported to an unexported name.
func unexport(name string) string {
r, sz := utf8.DecodeRuneInString(name)
if !unicode.IsUpper(r) {
// foo -> foo
return name
}
r2, sz2 := utf8.DecodeRuneInString(name[sz:])
if !unicode.IsUpper(r2) {
// Foo -> foo
return string(unicode.ToLower(r)) + name[sz:]
}
// UPPERWord -> upperWord
sbuf := new(strings.Builder)
sbuf.WriteRune(unicode.ToLower(r))
i := sz
r, sz = r2, sz2
for unicode.IsUpper(r) && sz > 0 {
r2, sz2 := utf8.DecodeRuneInString(name[i+sz:])
if sz2 > 0 && unicode.IsLower(r2) {
break
}
i += sz
sbuf.WriteRune(unicode.ToLower(r))
r, sz = r2, sz2
}
sbuf.WriteString(name[i:])
return sbuf.String()
}
// disambiguate picks a unique name, preferring name if it is already unique.
func disambiguate(name string, collides func(string) bool) string {
if !collides(name) {
return name
}
buf := []byte(name)
if len(buf) > 0 && buf[len(buf)-1] >= '0' && buf[len(buf)-1] <= '9' {
buf = append(buf, '_')
}
base := len(buf)
for n := 2; ; n++ {
buf = strconv.AppendInt(buf[:base], int64(n), 10)
sbuf := string(buf)
if !collides(sbuf) {
return sbuf
}
}
}
var errorType = types.Universe.Lookup("error").Type()

View File

@@ -15,6 +15,8 @@ import (
"strings"
"testing"
"time"
"unicode"
"unicode/utf8"
)
func TestGoose(t *testing.T) {
@@ -130,6 +132,83 @@ func TestGoose(t *testing.T) {
})
}
func TestUnexport(t *testing.T) {
tests := []struct {
name string
want string
}{
{"a", "a"},
{"ab", "ab"},
{"A", "a"},
{"AB", "ab"},
{"A_", "a_"},
{"ABc", "aBc"},
{"ABC", "abc"},
{"AB_", "ab_"},
{"foo", "foo"},
{"Foo", "foo"},
{"HTTPClient", "httpClient"},
{"IFace", "iFace"},
{"SNAKE_CASE", "snake_CASE"},
{"HTTP", "http"},
}
for _, test := range tests {
if got := unexport(test.name); got != test.want {
t.Errorf("unexport(%q) = %q; want %q", test.name, got, test.want)
}
}
}
func TestDisambiguate(t *testing.T) {
tests := []struct {
name string
collides map[string]bool
}{
{"foo", nil},
{"foo", map[string]bool{"foo": true}},
{"foo", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo1", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true}},
{"foo\u0661", map[string]bool{"foo": true, "foo1": true, "foo2": true, "foo\u0661": true}},
}
for _, test := range tests {
got := disambiguate(test.name, func(name string) bool { return test.collides[name] })
if !isIdent(got) {
t.Errorf("disambiguate(%q, %v) = %q; not an identifier", test.name, test.collides, got)
}
if test.collides[got] {
t.Errorf("disambiguate(%q, %v) = %q; ", test.name, test.collides, got)
}
}
}
func isIdent(s string) bool {
if len(s) == 0 {
if s == "foo" {
panic("BREAK3")
}
return false
}
r, i := utf8.DecodeRuneInString(s)
if !unicode.IsLetter(r) && r != '_' {
if s == "foo" {
panic("BREAK2")
}
return false
}
for i < len(s) {
r, sz := utf8.DecodeRuneInString(s[i:])
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
if s == "foo" {
panic("BREAK1")
}
return false
}
i += sz
}
return true
}
type testCase struct {
name string
pkg string

View File

@@ -0,0 +1,24 @@
package main
import (
stdcontext "context"
"fmt"
"os"
)
type context struct{}
func main() {
c, err := inject(stdcontext.Background(), struct{}{})
if err != nil {
fmt.Println("ERROR:", err)
os.Exit(1)
}
fmt.Println(c)
}
//goose:provide
func provide(ctx stdcontext.Context) (context, error) {
return context{}, nil
}

View File

@@ -0,0 +1,11 @@
//+build gooseinject
package main
import (
stdcontext "context"
)
//goose:use provide
func inject(context stdcontext.Context, err struct{}) (context, error)

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,24 @@
package main
import (
stdcontext "context"
"fmt"
"os"
)
type context struct{}
func main() {
c, err := inject(stdcontext.Background(), struct{}{})
if err != nil {
fmt.Println("ERROR:", err)
os.Exit(1)
}
fmt.Println(c)
}
//goose:provide
func provide(ctx stdcontext.Context) (context, error) {
return context{}, nil
}

View File

@@ -0,0 +1,14 @@
//+build gooseinject
package main
import (
stdcontext "context"
)
// The notable characteristic of this test is that there are no
// parameter names on the inject stub.
//goose:use provide
func inject(stdcontext.Context, struct{}) (context, error)

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1 @@
foo