wire: omit the local package identifier if it matches the package name (google/go-cloud#385)

Fixes google/go-cloud#424
This commit is contained in:
Robert van Gent
2018-09-12 14:05:54 -07:00
committed by Ross Light
parent c999a4d1b5
commit a8c7c0b8e1
10 changed files with 35 additions and 24 deletions

View File

@@ -6,7 +6,7 @@
package main package main
import ( import (
fmt "fmt" "fmt"
) )
// Injectors from foo.go: // Injectors from foo.go:

View File

@@ -6,7 +6,7 @@
package main package main
import ( import (
bar "example.com/bar" "example.com/bar"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -6,7 +6,7 @@
package main package main
import ( import (
os "os" "os"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -6,7 +6,7 @@
package main package main
import ( import (
foo "example.com/foo" "example.com/foo"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -6,8 +6,8 @@
package main package main
import ( import (
io "io" "io"
strings "strings" "strings"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -6,10 +6,10 @@
package main package main
import ( import (
bar "example.com/bar" "example.com/bar"
baz "example.com/baz" "example.com/baz"
foo "example.com/foo" "example.com/foo"
fmt "fmt" "fmt"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -7,9 +7,9 @@ package main
import ( import (
context2 "context" context2 "context"
fmt "fmt" "fmt"
os "os" "os"
reflect "reflect" "reflect"
) )
// Injectors from foo.go: // Injectors from foo.go:

View File

@@ -6,7 +6,7 @@
package main package main
import ( import (
bar "example.com/bar" "example.com/bar"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -6,7 +6,7 @@
package main package main
import ( import (
bar "example.com/bar" "example.com/bar"
) )
// Injectors from wire.go: // Injectors from wire.go:

View File

@@ -134,11 +134,18 @@ func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) {
} }
} }
// importInfo holds info about an import.
type importInfo struct {
name string
// fullpath is the full, possibly vendored, path.
fullpath string
}
// gen is the file-wide generator state. // gen is the file-wide generator state.
type gen struct { type gen struct {
currPackage string currPackage string
buf bytes.Buffer buf bytes.Buffer
imports map[string]string imports map[string]*importInfo
values map[ast.Expr]string values map[ast.Expr]string
prog *loader.Program // for positions and determining package names prog *loader.Program // for positions and determining package names
} }
@@ -146,7 +153,7 @@ type gen struct {
func newGen(prog *loader.Program, pkg string) *gen { func newGen(prog *loader.Program, pkg string) *gen {
return &gen{ return &gen{
currPackage: pkg, currPackage: pkg,
imports: make(map[string]string), imports: make(map[string]*importInfo),
values: make(map[ast.Expr]string), values: make(map[ast.Expr]string),
prog: prog, prog: prog,
} }
@@ -172,9 +179,13 @@ func (g *gen) frame() []byte {
} }
sort.Strings(imps) sort.Strings(imps)
for _, path := range imps { for _, path := range imps {
// TODO(light): Omit the local package identifier if it matches // Omit the local package identifier if it matches the package name.
// the package name. info := g.imports[path]
fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path) if g.prog.Package(info.fullpath).Pkg.Name() == info.name {
fmt.Fprintf(&buf, "\t%q\n", path)
} else {
fmt.Fprintf(&buf, "\t%s %q\n", info.name, path)
}
} }
buf.WriteString(")\n\n") buf.WriteString(")\n\n")
} }
@@ -414,21 +425,21 @@ func (g *gen) qualifyImport(path string) string {
if i := strings.LastIndex(path, vendorPart); i != -1 && (i == 0 || path[i-1] == '/') { if i := strings.LastIndex(path, vendorPart); i != -1 && (i == 0 || path[i-1] == '/') {
unvendored = path[i+len(vendorPart):] unvendored = path[i+len(vendorPart):]
} }
if name := g.imports[unvendored]; name != "" { if info := g.imports[unvendored]; info != nil {
return name return info.name
} }
// TODO(light): Use parts of import path to disambiguate. // TODO(light): Use parts of import path to disambiguate.
name := disambiguate(g.prog.Package(path).Pkg.Name(), func(n string) bool { name := disambiguate(g.prog.Package(path).Pkg.Name(), func(n string) bool {
// Don't let an import take the "err" name. That's annoying. // Don't let an import take the "err" name. That's annoying.
return n == "err" || g.nameInFileScope(n) return n == "err" || g.nameInFileScope(n)
}) })
g.imports[unvendored] = name g.imports[unvendored] = &importInfo{name: name, fullpath: path}
return name return name
} }
func (g *gen) nameInFileScope(name string) bool { func (g *gen) nameInFileScope(name string) bool {
for _, other := range g.imports { for _, other := range g.imports {
if other == name { if other.name == name {
return true return true
} }
} }