goose: strip vendor from generated import paths
This allows goose to work more gracefully in a vgo setting. Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
@@ -44,7 +44,8 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
}
|
||||
pkgInfo := prog.InitialPackages()[0]
|
||||
g := newGen(prog, pkgInfo.Pkg.Path())
|
||||
mc := newProviderSetCache(prog)
|
||||
r := newImportResolver(conf, prog.Fset)
|
||||
mc := newProviderSetCache(prog, r)
|
||||
var directives []directive
|
||||
for _, f := range pkgInfo.Files {
|
||||
if !isInjectFile(f) {
|
||||
@@ -67,7 +68,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
if d.kind != "use" {
|
||||
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
|
||||
}
|
||||
ref, err := parseProviderSetRef(d.line, fileScope, g.currPackage, d.pos)
|
||||
ref, err := parseProviderSetRef(r, d.line, fileScope, g.currPackage, d.pos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
|
||||
}
|
||||
@@ -278,7 +279,13 @@ func (g *gen) qualifyImport(path string) string {
|
||||
if path == g.currPackage {
|
||||
return ""
|
||||
}
|
||||
if name := g.imports[path]; name != "" {
|
||||
// TODO(light): this is depending on details of the current loader.
|
||||
const vendorPart = "vendor/"
|
||||
unvendored := path
|
||||
if i := strings.LastIndex(path, vendorPart); i != -1 && (i == 0 || path[i-1] == '/') {
|
||||
unvendored = path[i+len(vendorPart):]
|
||||
}
|
||||
if name := g.imports[unvendored]; name != "" {
|
||||
return name
|
||||
}
|
||||
// TODO(light): use parts of import path to disambiguate.
|
||||
@@ -286,7 +293,7 @@ func (g *gen) qualifyImport(path string) string {
|
||||
// Don't let an import take the "err" name. That's annoying.
|
||||
return n == "err" || g.nameInFileScope(n)
|
||||
})
|
||||
g.imports[path] = name
|
||||
g.imports[unvendored] = name
|
||||
return name
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ package goose
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/build"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -34,7 +36,7 @@ type providerInfo struct {
|
||||
}
|
||||
|
||||
// findProviderSets processes a package and extracts the provider sets declared in it.
|
||||
func findProviderSets(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
|
||||
func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
|
||||
sets := make(map[string]*providerSet)
|
||||
var directives []directive
|
||||
for _, f := range files {
|
||||
@@ -55,7 +57,7 @@ func findProviderSets(fset *token.FileSet, pkg *types.Package, typeInfo *types.I
|
||||
return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos))
|
||||
}
|
||||
name, spec := d.line[:i], d.line[i+1:]
|
||||
ref, err := parseProviderSetRef(spec, fileScope, pkg.Path(), d.pos)
|
||||
ref, err := parseProviderSetRef(r, spec, fileScope, pkg.Path(), d.pos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
|
||||
}
|
||||
@@ -174,12 +176,14 @@ type providerSetCache struct {
|
||||
sets map[string]map[string]*providerSet
|
||||
fset *token.FileSet
|
||||
prog *loader.Program
|
||||
r *importResolver
|
||||
}
|
||||
|
||||
func newProviderSetCache(prog *loader.Program) *providerSetCache {
|
||||
func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCache {
|
||||
return &providerSetCache{
|
||||
fset: prog.Fset,
|
||||
prog: prog,
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +199,7 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
||||
mc.sets = make(map[string]map[string]*providerSet)
|
||||
}
|
||||
pkg := mc.prog.Package(ref.importPath)
|
||||
mods, err := findProviderSets(mc.fset, pkg.Pkg, &pkg.Info, pkg.Files)
|
||||
mods, err := findProviderSets(mc.fset, pkg.Pkg, mc.r, &pkg.Info, pkg.Files)
|
||||
if err != nil {
|
||||
mc.sets[ref.importPath] = nil
|
||||
return nil, err
|
||||
@@ -214,7 +218,7 @@ type providerSetRef struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func parseProviderSetRef(ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
|
||||
func parseProviderSetRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
|
||||
// TODO(light): verify that provider set name is an identifier before returning
|
||||
|
||||
i := strings.LastIndexByte(ref, '.')
|
||||
@@ -227,6 +231,10 @@ func parseProviderSetRef(ref string, s *types.Scope, pkg string, pos token.Pos)
|
||||
if err != nil {
|
||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
|
||||
}
|
||||
path, err = r.resolve(pos, path)
|
||||
if err != nil {
|
||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %v", ref, err)
|
||||
}
|
||||
return providerSetRef{importPath: path, name: name}, nil
|
||||
}
|
||||
_, obj := s.LookupParent(imp, pos)
|
||||
@@ -244,6 +252,36 @@ func (ref providerSetRef) String() string {
|
||||
return strconv.Quote(ref.importPath) + "." + ref.name
|
||||
}
|
||||
|
||||
type importResolver struct {
|
||||
fset *token.FileSet
|
||||
bctx *build.Context
|
||||
findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
|
||||
}
|
||||
|
||||
func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver {
|
||||
r := &importResolver{
|
||||
fset: fset,
|
||||
bctx: c.Build,
|
||||
findPackage: c.FindPackage,
|
||||
}
|
||||
if r.bctx == nil {
|
||||
r.bctx = &build.Default
|
||||
}
|
||||
if r.findPackage == nil {
|
||||
r.findPackage = (*build.Context).Import
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
|
||||
dir := filepath.Dir(r.fset.File(pos).Name())
|
||||
pkg, err := r.findPackage(r.bctx, path, dir, build.FindOnly)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return pkg.ImportPath, nil
|
||||
}
|
||||
|
||||
type directive struct {
|
||||
pos token.Pos
|
||||
kind string
|
||||
|
||||
2
internal/goose/testdata/Vendor/bar/dummy.go
vendored
Normal file
2
internal/goose/testdata/Vendor/bar/dummy.go
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package bar is left intentionally blank.
|
||||
package bar
|
||||
7
internal/goose/testdata/Vendor/foo/foo.go
vendored
Normal file
7
internal/goose/testdata/Vendor/foo/foo.go
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println(injectedMessage())
|
||||
}
|
||||
11
internal/goose/testdata/Vendor/foo/goose.go
vendored
Normal file
11
internal/goose/testdata/Vendor/foo/goose.go
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
//+build gooseinject
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "bar"
|
||||
)
|
||||
|
||||
//goose:use "bar".Message
|
||||
|
||||
func injectedMessage() string
|
||||
9
internal/goose/testdata/Vendor/foo/vendor/bar/bar.go
vendored
Normal file
9
internal/goose/testdata/Vendor/foo/vendor/bar/bar.go
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
// Package bar is the vendored copy of bar which contains the real provider.
|
||||
package bar
|
||||
|
||||
//goose:provide Message
|
||||
|
||||
// ProvideMessage provides a friendly user greeting.
|
||||
func ProvideMessage() string {
|
||||
return "Hello, World!"
|
||||
}
|
||||
1
internal/goose/testdata/Vendor/out.txt
vendored
Normal file
1
internal/goose/testdata/Vendor/out.txt
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Hello, World!
|
||||
1
internal/goose/testdata/Vendor/pkg
vendored
Normal file
1
internal/goose/testdata/Vendor/pkg
vendored
Normal file
@@ -0,0 +1 @@
|
||||
foo
|
||||
Reference in New Issue
Block a user