wire: store value expressions in package variables (google/go-cloud#135)

The code assigns a local unnecessarily, but this should have no
appreciable effect on semantics, just readability.

Fixes google/go-cloud#104
This commit is contained in:
Ross Light
2018-06-27 08:17:45 -07:00
parent 2eb9d5ea1f
commit 338b1da068
5 changed files with 103 additions and 18 deletions

View File

@@ -304,9 +304,10 @@ func injectFoo() Foo {
}
```
It's important to note that the expression will be copied, so references to
variables will be evaluated during the call to the injector. `gowire` will emit
an error if the expression calls any functions.
It's important to note that the expression will be copied to the injector's
package; references to variables will be evaluated during the injector
package's initialization. `gowire` will emit an error if the expression calls
any functions or receives from any channels.
### Cleanup functions

View File

@@ -17,10 +17,10 @@ package main
import "fmt"
func main() {
// Value should be deferred until function call.
// Mutating value; value should have been stored at package initialization.
msg = "Hello, World!"
fmt.Println(injectedMessage())
}
var msg string
var msg string = "Package init"

View File

@@ -1 +1 @@
Hello, World!
Package init

View File

@@ -151,6 +151,7 @@ type gen struct {
currPackage string
buf bytes.Buffer
imports map[string]string
values map[ast.Expr]string
prog *loader.Program // for positions and determining package names
}
@@ -158,6 +159,7 @@ func newGen(prog *loader.Program, pkg string) *gen {
return &gen{
currPackage: pkg,
imports: make(map[string]string),
values: make(map[ast.Expr]string),
prog: prog,
}
}
@@ -207,6 +209,12 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error
if err != nil {
return err
}
type pendingVar struct {
name string
expr ast.Expr
typeInfo *types.Info
}
var pendingVars []pendingVar
for i := range calls {
c := &calls[i]
if c.hasCleanup && !injectSig.cleanup {
@@ -221,6 +229,16 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error
ts := types.TypeString(c.out, nil)
return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err)
}
if g.values[c.valueExpr] == "" {
t := c.valueTypeInfo.TypeOf(c.valueExpr)
name := disambiguate("_wire"+export(typeVariableName(t))+"Value", g.nameInFileScope)
g.values[c.valueExpr] = name
pendingVars = append(pendingVars, pendingVar{
name: name,
expr: c.valueExpr,
typeInfo: c.valueTypeInfo,
})
}
}
}
@@ -235,6 +253,15 @@ func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error
errVar: disambiguate("err", g.nameInFileScope),
discard: false,
})
if len(pendingVars) > 0 {
g.p("var (\n")
for _, pv := range pendingVars {
g.p("\t%s = ", pv.name)
g.writeAST(pv.typeInfo, pv.expr)
g.p("\n")
}
g.p(")\n\n")
}
return nil
}
@@ -398,6 +425,11 @@ func (g *gen) nameInFileScope(name string) bool {
return true
}
}
for _, other := range g.values {
if other == name {
return true
}
}
_, obj := g.prog.Package(g.currPackage).Pkg.Scope().LookupParent(name, 0)
return obj != nil
}
@@ -434,7 +466,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
pi := params.At(i)
a := pi.Name()
if a == "" || a == "_" {
a = typeVariableName(pi.Type())
a = unexport(typeVariableName(pi.Type()))
if a == "" {
a = "arg"
}
@@ -454,7 +486,7 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
}
for i := range calls {
c := &calls[i]
lname := typeVariableName(c.out)
lname := unexport(typeVariableName(c.out))
if lname == "" {
lname = "v"
}
@@ -553,10 +585,7 @@ func (ig *injectorGen) structProviderCall(lname string, c *call) {
}
func (ig *injectorGen) valueExpr(lname string, c *call) {
ig.p("\t%s", lname)
ig.p(" := ")
ig.writeAST(c.valueTypeInfo, c.valueExpr)
ig.p("\n")
ig.p("\t%s := %s\n", lname, ig.g.values[c.valueExpr])
}
// nameInInjector reports whether name collides with any other identifier
@@ -626,21 +655,28 @@ func zeroValue(t types.Type, qf types.Qualifier) string {
}
// typeVariableName invents a variable name derived from the type name
// or returns the empty string if one could not be found.
// or returns the empty string if one could not be found. There are no
// guarantees about whether the name is exported or unexported: call
// export() or unexport() to convert.
func typeVariableName(t types.Type) string {
if p, ok := t.(*types.Pointer); ok {
t = p.Elem()
}
tn, ok := t.(*types.Named)
if !ok {
return ""
switch t := t.(type) {
case *types.Basic:
return t.Name()
case *types.Named:
// TODO(light): Include package name when appropriate.
return t.Obj().Name()
}
// TODO(light): Include package name when appropriate.
return unexport(tn.Obj().Name())
return ""
}
// unexport converts a name that is potentially exported to an unexported name.
func unexport(name string) string {
if name == "" {
return ""
}
r, sz := utf8.DecodeRuneInString(name)
if !unicode.IsUpper(r) {
// foo -> foo
@@ -669,6 +705,23 @@ func unexport(name string) string {
return sbuf.String()
}
// export converts a name that is potentially unexported to an exported name.
func export(name string) string {
if name == "" {
return ""
}
r, sz := utf8.DecodeRuneInString(name)
if unicode.IsUpper(r) {
// Foo -> Foo
return name
}
// fooBar -> FooBar
sbuf := new(strings.Builder)
sbuf.WriteRune(unicode.ToUpper(r))
sbuf.WriteString(name[sz:])
return sbuf.String()
}
// disambiguate picks a unique name, preferring name if it is already unique.
func disambiguate(name string, collides func(string) bool) string {
if !collides(name) {

View File

@@ -157,6 +157,7 @@ func TestUnexport(t *testing.T) {
name string
want string
}{
{"", ""},
{"a", "a"},
{"ab", "ab"},
{"A", "a"},
@@ -179,6 +180,36 @@ func TestUnexport(t *testing.T) {
}
}
func TestExport(t *testing.T) {
tests := []struct {
name string
want string
}{
{"", ""},
{"a", "A"},
{"ab", "Ab"},
{"A", "A"},
{"AB", "AB"},
{"A_", "A_"},
{"ABc", "ABc"},
{"ABC", "ABC"},
{"AB_", "AB_"},
{"foo", "Foo"},
{"Foo", "Foo"},
{"HTTPClient", "HTTPClient"},
{"httpClient", "HttpClient"},
{"IFace", "IFace"},
{"iFace", "IFace"},
{"SNAKE_CASE", "SNAKE_CASE"},
{"HTTP", "HTTP"},
}
for _, test := range tests {
if got := export(test.name); got != test.want {
t.Errorf("export(%q) = %q; want %q", test.name, got, test.want)
}
}
}
func TestDisambiguate(t *testing.T) {
tests := []struct {
name string