From 338b1da068a7ff8472fd1b0beaa5069c6ff5cb16 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Wed, 27 Jun 2018 08:17:45 -0700 Subject: [PATCH] wire: store value expressions in package variables (google/go-cloud#135) The code assigns a local unnecessarily, but this should have no appreciable effect on semantics, just readability. Fixes google/go-cloud#104 --- README.md | 7 +- internal/wire/testdata/VarValue/foo/foo.go | 4 +- internal/wire/testdata/VarValue/out.txt | 2 +- internal/wire/wire.go | 77 ++++++++++++++++++---- internal/wire/wire_test.go | 31 +++++++++ 5 files changed, 103 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 4ec9e1e..14c1979 100644 --- a/README.md +++ b/README.md @@ -304,9 +304,10 @@ func injectFoo() Foo { } ``` -It's important to note that the expression will be copied, so references to -variables will be evaluated during the call to the injector. `gowire` will emit -an error if the expression calls any functions. +It's important to note that the expression will be copied to the injector's +package; references to variables will be evaluated during the injector +package's initialization. `gowire` will emit an error if the expression calls +any functions or receives from any channels. ### Cleanup functions diff --git a/internal/wire/testdata/VarValue/foo/foo.go b/internal/wire/testdata/VarValue/foo/foo.go index a81b517..0c05f7e 100644 --- a/internal/wire/testdata/VarValue/foo/foo.go +++ b/internal/wire/testdata/VarValue/foo/foo.go @@ -17,10 +17,10 @@ package main import "fmt" func main() { - // Value should be deferred until function call. + // Mutating value; value should have been stored at package initialization. msg = "Hello, World!" fmt.Println(injectedMessage()) } -var msg string +var msg string = "Package init" diff --git a/internal/wire/testdata/VarValue/out.txt b/internal/wire/testdata/VarValue/out.txt index 8ab686e..f4a93ef 100644 --- a/internal/wire/testdata/VarValue/out.txt +++ b/internal/wire/testdata/VarValue/out.txt @@ -1 +1 @@ -Hello, World! +Package init diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 11cf805..327d6b5 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -151,6 +151,7 @@ type gen struct { currPackage string buf bytes.Buffer imports map[string]string + values map[ast.Expr]string prog *loader.Program // for positions and determining package names } @@ -158,6 +159,7 @@ func newGen(prog *loader.Program, pkg string) *gen { return &gen{ currPackage: pkg, imports: make(map[string]string), + values: make(map[ast.Expr]string), prog: prog, } } @@ -207,6 +209,12 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error if err != nil { return err } + type pendingVar struct { + name string + expr ast.Expr + typeInfo *types.Info + } + var pendingVars []pendingVar for i := range calls { c := &calls[i] if c.hasCleanup && !injectSig.cleanup { @@ -221,6 +229,16 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error ts := types.TypeString(c.out, nil) return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err) } + if g.values[c.valueExpr] == "" { + t := c.valueTypeInfo.TypeOf(c.valueExpr) + name := disambiguate("_wire"+export(typeVariableName(t))+"Value", g.nameInFileScope) + g.values[c.valueExpr] = name + pendingVars = append(pendingVars, pendingVar{ + name: name, + expr: c.valueExpr, + typeInfo: c.valueTypeInfo, + }) + } } } @@ -235,6 +253,15 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error errVar: disambiguate("err", g.nameInFileScope), discard: false, }) + if len(pendingVars) > 0 { + g.p("var (\n") + for _, pv := range pendingVars { + g.p("\t%s = ", pv.name) + g.writeAST(pv.typeInfo, pv.expr) + g.p("\n") + } + g.p(")\n\n") + } return nil } @@ -398,6 +425,11 @@ func (g *gen) nameInFileScope(name string) bool { return true } } + for _, other := range g.values { + if other == name { + return true + } + } _, obj := g.prog.Package(g.currPackage).Pkg.Scope().LookupParent(name, 0) return obj != nil } @@ -434,7 +466,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal pi := params.At(i) a := pi.Name() if a == "" || a == "_" { - a = typeVariableName(pi.Type()) + a = unexport(typeVariableName(pi.Type())) if a == "" { a = "arg" } @@ -454,7 +486,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal } for i := range calls { c := &calls[i] - lname := typeVariableName(c.out) + lname := unexport(typeVariableName(c.out)) if lname == "" { lname = "v" } @@ -553,10 +585,7 @@ func (ig *injectorGen) structProviderCall(lname string, c *call) { } func (ig *injectorGen) valueExpr(lname string, c *call) { - ig.p("\t%s", lname) - ig.p(" := ") - ig.writeAST(c.valueTypeInfo, c.valueExpr) - ig.p("\n") + ig.p("\t%s := %s\n", lname, ig.g.values[c.valueExpr]) } // nameInInjector reports whether name collides with any other identifier @@ -626,21 +655,28 @@ func zeroValue(t types.Type, qf types.Qualifier) string { } // typeVariableName invents a variable name derived from the type name -// or returns the empty string if one could not be found. +// or returns the empty string if one could not be found. There are no +// guarantees about whether the name is exported or unexported: call +// export() or unexport() to convert. func typeVariableName(t types.Type) string { if p, ok := t.(*types.Pointer); ok { t = p.Elem() } - tn, ok := t.(*types.Named) - if !ok { - return "" + switch t := t.(type) { + case *types.Basic: + return t.Name() + case *types.Named: + // TODO(light): Include package name when appropriate. + return t.Obj().Name() } - // TODO(light): Include package name when appropriate. - return unexport(tn.Obj().Name()) + return "" } // unexport converts a name that is potentially exported to an unexported name. func unexport(name string) string { + if name == "" { + return "" + } r, sz := utf8.DecodeRuneInString(name) if !unicode.IsUpper(r) { // foo -> foo @@ -669,6 +705,23 @@ func unexport(name string) string { return sbuf.String() } +// export converts a name that is potentially unexported to an exported name. +func export(name string) string { + if name == "" { + return "" + } + r, sz := utf8.DecodeRuneInString(name) + if unicode.IsUpper(r) { + // Foo -> Foo + return name + } + // fooBar -> FooBar + sbuf := new(strings.Builder) + sbuf.WriteRune(unicode.ToUpper(r)) + sbuf.WriteString(name[sz:]) + return sbuf.String() +} + // disambiguate picks a unique name, preferring name if it is already unique. func disambiguate(name string, collides func(string) bool) string { if !collides(name) { diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index d89da7d..08cd07e 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -157,6 +157,7 @@ func TestUnexport(t *testing.T) { name string want string }{ + {"", ""}, {"a", "a"}, {"ab", "ab"}, {"A", "a"}, @@ -179,6 +180,36 @@ func TestUnexport(t *testing.T) { } } +func TestExport(t *testing.T) { + tests := []struct { + name string + want string + }{ + {"", ""}, + {"a", "A"}, + {"ab", "Ab"}, + {"A", "A"}, + {"AB", "AB"}, + {"A_", "A_"}, + {"ABc", "ABc"}, + {"ABC", "ABC"}, + {"AB_", "AB_"}, + {"foo", "Foo"}, + {"Foo", "Foo"}, + {"HTTPClient", "HTTPClient"}, + {"httpClient", "HttpClient"}, + {"IFace", "IFace"}, + {"iFace", "IFace"}, + {"SNAKE_CASE", "SNAKE_CASE"}, + {"HTTP", "HTTP"}, + } + for _, test := range tests { + if got := export(test.name); got != test.want { + t.Errorf("export(%q) = %q; want %q", test.name, got, test.want) + } + } +} + func TestDisambiguate(t *testing.T) { tests := []struct { name string