527 lines
14 KiB
Go
527 lines
14 KiB
Go
package goose
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/build"
|
|
"go/token"
|
|
"go/types"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"golang.org/x/tools/go/loader"
|
|
)
|
|
|
|
// A providerSet describes a set of providers. The zero value is an empty
// providerSet.
type providerSet struct {
	// providers are the provider functions declared directly in this set
	// (functions annotated with goose:provide).
	providers []*providerInfo
	// imports are references to other provider sets that were pulled into
	// this set by goose:import directives.
	imports []providerSetImport
}
|
|
|
|
// providerSetImport is a reference to another provider set, together with
// the position of the goose:import directive that introduced it.
type providerSetImport struct {
	providerSetRef
	// pos is the position recorded for the directive (currently the start
	// of the comment group that contains it).
	pos token.Pos
}
|
|
|
|
// providerInfo records the signature of a provider function.
type providerInfo struct {
	// importPath is the import path of the package declaring the provider.
	importPath string
	// funcName is the provider function's name within that package.
	funcName string
	// pos is the position of the function declaration.
	pos token.Pos
	// args describes the provider's parameters, in declaration order.
	args []providerInput
	// out is the type the provider produces (its first result).
	out types.Type
	// hasErr reports whether the provider also returns an error as its
	// second result.
	hasErr bool
}
|
|
|
|
// providerInput describes a single parameter of a provider function.
type providerInput struct {
	// typ is the parameter's type.
	typ types.Type
	// optional is true when the parameter was named in a goose:optional
	// directive on the provider.
	optional bool
}
|
|
|
|
// findContext bundles the per-package state needed to discover the provider
// sets declared in a type-checked package.
type findContext struct {
	// fset maps token positions to file positions for error messages.
	fset *token.FileSet
	// pkg is the package being scanned.
	pkg *types.Package
	// typeInfo is the type-checker output for pkg; it is used to look up
	// file scopes and the objects of declared functions.
	typeInfo *types.Info
	// r resolves quoted import paths in provider set references.
	r *importResolver
}
|
|
|
|
// findProviderSets scans the given files of a type-checked package and
// returns every provider set they declare, keyed by set name.
//
// Directive groups attached to a declaration are handed to
// processDeclDirectives; directives that are not attached to any declaration
// (such as goose:import) are handed one at a time to
// processUnassociatedDirective. The first error stops the scan and is
// returned with a nil map.
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) {
	found := make(map[string]*providerSet)
	for _, file := range files {
		scope := fctx.typeInfo.Scopes[file]
		if scope == nil {
			fname := fctx.fset.File(file.Pos()).Name()
			return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fname)
		}
		for _, group := range parseFile(fctx.fset, file) {
			if group.decl == nil {
				for _, dir := range group.dirs {
					if err := processUnassociatedDirective(fctx, found, scope, dir); err != nil {
						return nil, err
					}
				}
				continue
			}
			if err := processDeclDirectives(fctx, found, scope, group); err != nil {
				return nil, err
			}
		}
	}
	return found, nil
}
|
|
|
|
// processUnassociatedDirective handles a directive that is not attached to
// any top-level declaration.
//
// goose:import TARGET SETREF... adds each referenced provider set to the
// imports of the set TARGET, creating TARGET when it does not exist yet and
// skipping references TARGET already imports. A reference into another
// package is only accepted when fctx.pkg imports that package. goose:use is
// ignored here; goose:provide and goose:optional are errors because they
// must annotate a function; any other kind is reported as unknown.
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
	switch d.kind {
	case "import":
		// Handled below.
	case "use":
		// Ignore, picked up by injector flow.
		return nil
	case "provide", "optional":
		return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
	default:
		return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind)
	}

	args := d.args()
	if len(args) < 2 {
		return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
	}
	target := args[0]
	thisPkg := fctx.pkg.Path()
	isImported := func(path string) bool {
		for _, imp := range fctx.pkg.Imports() {
			if imp.Path() == path {
				return true
			}
		}
		return false
	}

	for _, spec := range args[1:] {
		ref, err := parseProviderSetRef(fctx.r, spec, scope, thisPkg, d.pos)
		if err != nil {
			return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
		}
		if ref.importPath != thisPkg && !isImported(ref.importPath) {
			return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), target, ref.importPath)
		}

		set := sets[target]
		if set == nil {
			sets[target] = &providerSet{
				imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
			}
			continue
		}
		already := false
		for _, prev := range set.imports {
			if prev.providerSetRef == ref {
				already = true
				break
			}
		}
		if !already {
			set.imports = append(set.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
		}
	}
	return nil
}
|
|
|
|
// processDeclDirectives processes the directives associated with a top-level
// declaration.
//
// If the group has no goose:provide directive, the declaration is not a
// provider. In that case goose:optional is rejected and other directives
// (e.g. goose:use) are left alone. Otherwise the declaration must be a
// package-level function (not a method) that returns one value, optionally
// followed by an error, and has no two parameters of identical type. Every
// name listed in a goose:optional directive must be one of its parameters.
// The resulting provider is added to the set named by the goose:provide
// argument, or to a set named after the function when no argument is given.
// A set may not contain two providers of identical output type.
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error {
	p, err := dg.single(fctx.fset, "provide")
	if err != nil {
		return err
	}
	if !p.isValid() {
		for _, d := range dg.dirs {
			if d.kind == "optional" {
				return fmt.Errorf("%v: cannot use goose:%s directive on non-provider", fctx.fset.Position(d.pos), d.kind)
			}
		}
		return nil
	}
	fn, ok := dg.decl.(*ast.FuncDecl)
	if !ok {
		return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos))
	}
	if fn.Recv != nil {
		// providerInfo identifies a provider only by package import path
		// and function name. That cannot express a method, which would also
		// need a receiver value, so reject methods here. Otherwise we would
		// record a provider that can never be referenced correctly.
		return fmt.Errorf("%v: only functions can be marked as providers; %s is a method", fctx.fset.Position(p.pos), fn.Name.Name)
	}
	sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)

	// Mark the parameters named by goose:optional directives.
	optionals := make([]bool, sig.Params().Len())
	for _, d := range dg.dirs {
		if d.kind != "optional" {
			continue
		}
		for _, arg := range d.args() {
			pi := paramIndex(sig.Params(), arg)
			if pi == -1 {
				return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
			}
			optionals[pi] = true
		}
	}

	fpos := fn.Pos()
	r := sig.Results()
	var hasErr bool
	switch r.Len() {
	case 1:
		hasErr = false
	case 2:
		if t := r.At(1).Type(); !types.Identical(t, errorType) {
			return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
		}
		hasErr = true
	default:
		return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
	}
	out := r.At(0).Type()
	params := sig.Params()
	provider := &providerInfo{
		importPath: fctx.pkg.Path(),
		funcName:   fn.Name.Name,
		pos:        fn.Pos(),
		args:       make([]providerInput, params.Len()),
		out:        out,
		hasErr:     hasErr,
	}
	for i := 0; i < params.Len(); i++ {
		provider.args[i] = providerInput{
			typ:      params.At(i).Type(),
			optional: optionals[i],
		}
		// providerInput identifies a parameter only by its type, so two
		// parameters of identical type could not be told apart.
		for j := 0; j < i; j++ {
			if types.Identical(provider.args[i].typ, provider.args[j].typ) {
				return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
			}
		}
	}

	providerSetName := fn.Name.Name
	if args := p.args(); len(args) == 1 {
		// TODO(light): validate identifier
		providerSetName = args[0]
	} else if len(args) > 1 {
		return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos))
	}

	if mod := sets[providerSetName]; mod != nil {
		for _, other := range mod.providers {
			if types.Identical(other.out, provider.out) {
				return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
			}
		}
		mod.providers = append(mod.providers, provider)
	} else {
		sets[providerSetName] = &providerSet{
			providers: []*providerInfo{provider},
		}
	}
	return nil
}
|
|
|
|
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
	// sets maps a package import path to that package's provider sets,
	// keyed by set name. Entries are filled in on demand by get.
	sets map[string]map[string]*providerSet
	// fset is used to report positions in error messages.
	fset *token.FileSet
	// prog is the loaded program whose packages are scanned.
	prog *loader.Program
	// r resolves quoted import paths in provider set references.
	r *importResolver
}
|
|
|
|
func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCache {
|
|
return &providerSetCache{
|
|
fset: prog.Fset,
|
|
prog: prog,
|
|
r: r,
|
|
}
|
|
}
|
|
|
|
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
|
if mods, cached := mc.sets[ref.importPath]; cached {
|
|
mod := mods[ref.name]
|
|
if mod == nil {
|
|
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
|
|
}
|
|
return mod, nil
|
|
}
|
|
if mc.sets == nil {
|
|
mc.sets = make(map[string]map[string]*providerSet)
|
|
}
|
|
pkg := mc.prog.Package(ref.importPath)
|
|
mods, err := findProviderSets(findContext{
|
|
fset: mc.fset,
|
|
pkg: pkg.Pkg,
|
|
typeInfo: &pkg.Info,
|
|
r: mc.r,
|
|
}, pkg.Files)
|
|
if err != nil {
|
|
mc.sets[ref.importPath] = nil
|
|
return nil, err
|
|
}
|
|
mc.sets[ref.importPath] = mods
|
|
mod := mods[ref.name]
|
|
if mod == nil {
|
|
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
|
|
}
|
|
return mod, nil
|
|
}
|
|
|
|
// A providerSetRef is a parsed reference to a collection of providers.
type providerSetRef struct {
	// importPath is the import path of the package that declares the set.
	importPath string
	// name is the provider set's name within that package.
	name string
}
|
|
|
|
// parseProviderSetRef parses a provider set reference as written in a
// directive. Three forms are accepted:
//
//	Name          a set in the package whose import path is pkg
//	ident.Name    a set in the package imported as ident, looked up in s
//	"path".Name   a set in the package at path, resolved with r relative to
//	              the file that contains pos
//
// Everything after the last '.' is taken as the set name.
func parseProviderSetRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
	// TODO(light): verify that provider set name is an identifier before returning

	dot := strings.LastIndexByte(ref, '.')
	if dot < 0 {
		return providerSetRef{importPath: pkg, name: ref}, nil
	}
	qualifier, setName := ref[:dot], ref[dot+1:]

	if !strings.HasPrefix(qualifier, `"`) {
		// Package identifier form: it must name an import in scope.
		_, obj := s.LookupParent(qualifier, pos)
		if obj == nil {
			return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, qualifier)
		}
		pkgName, isPkg := obj.(*types.PkgName)
		if !isPkg {
			return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, qualifier)
		}
		return providerSetRef{importPath: pkgName.Imported().Path(), name: setName}, nil
	}

	// Quoted import path form.
	unquoted, err := strconv.Unquote(qualifier)
	if err != nil {
		return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
	}
	resolved, err := r.resolve(pos, unquoted)
	if err != nil {
		return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %v", ref, err)
	}
	return providerSetRef{importPath: resolved, name: setName}, nil
}
|
|
|
|
// String renders ref in the "path".Name form: the import path quoted with
// Go string syntax, a dot, then the set name.
func (ref providerSetRef) String() string {
	// %q on a string quotes exactly like strconv.Quote.
	return fmt.Sprintf("%q.%s", ref.importPath, ref.name)
}
|
|
|
|
// importResolver maps an import path, as written in a source file, to the
// import path of the package it denotes, by looking the package up with a
// build context from that file's directory.
type importResolver struct {
	// fset locates the file that contains a given position.
	fset *token.FileSet
	// bctx is the build context passed to findPackage.
	bctx *build.Context
	// findPackage looks up a package; it has the same signature as
	// (*build.Context).Import, which is the default.
	findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
}
|
|
|
|
// newImportResolver builds an importResolver from a loader configuration.
// It uses c.Build and c.FindPackage when they are set, and falls back to
// build.Default and (*build.Context).Import otherwise.
func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver {
	bctx := c.Build
	if bctx == nil {
		bctx = &build.Default
	}
	find := c.FindPackage
	if find == nil {
		find = (*build.Context).Import
	}
	return &importResolver{
		fset:        fset,
		bctx:        bctx,
		findPackage: find,
	}
}
|
|
|
|
// resolve returns the import path of the package that path denotes when it
// is imported from the file containing pos. The package is looked up with
// findPackage in build.FindOnly mode from that file's directory. The result
// may differ from path, for example under vendoring.
func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
	file := r.fset.File(pos)
	fromDir := filepath.Dir(file.Name())
	bp, err := r.findPackage(r.bctx, path, fromDir, build.FindOnly)
	if err != nil {
		return "", err
	}
	return bp.ImportPath, nil
}
|
|
|
|
// A directive is a parsed goose comment.
type directive struct {
	// pos is the position recorded for the directive (currently the start
	// of the comment group that contains it).
	pos token.Pos
	// kind is the directive name, e.g. "provide" for goose:provide.
	kind string
	// line is the rest of the directive after the kind, with surrounding
	// whitespace trimmed; args splits it into tokens.
	line string
}
|
|
|
|
// A directiveGroup is a set of directives associated with a particular
// declaration.
type directiveGroup struct {
	// decl is the declaration the directives are attached to, or nil for
	// the group of directives not attached to any declaration.
	decl ast.Decl
	// dirs are the group's directives.
	dirs []directive
}
|
|
|
|
// parseFile extracts the directives from a file, grouped by declaration.
//
// Only goose:provide, goose:optional and goose:use stay attached to the
// top-level declaration whose comments contain them. All other directives
// (such as goose:import), and directives in comments not attached to any
// declaration, go into a single group with a nil decl. If that group is
// non-empty it is the first element of the result. Declarations without
// directives are omitted. The result is deterministic for a given file.
func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
	cmap := ast.NewCommentMap(fset, f, f.Comments)
	// Reserve first group for directives that don't associate with a
	// declaration, like import.
	groups := make([]directiveGroup, 1, len(f.Decls)+1)
	// Walk declarations and add to groups.
	for _, decl := range f.Decls {
		grp := directiveGroup{decl: decl}
		ast.Inspect(decl, func(node ast.Node) bool {
			if g := cmap[node]; len(g) > 0 {
				for _, cg := range g {
					start := len(grp.dirs)
					grp.dirs = extractDirectives(grp.dirs, cg)

					// Move directives that don't associate into the unassociated group.
					n := 0
					for i := start; i < len(grp.dirs); i++ {
						if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" {
							grp.dirs[start+n] = grp.dirs[i]
							n++
						} else {
							groups[0].dirs = append(groups[0].dirs, grp.dirs[i])
						}
					}
					grp.dirs = grp.dirs[:start+n]
				}
				// These comments are consumed; drop them so they are not
				// treated as unassociated comments below.
				delete(cmap, node)
			}
			return true
		})
		if len(grp.dirs) > 0 {
			groups = append(groups, grp)
		}
	}

	// Place remaining directives into the unassociated group. The leftover
	// comments may be attached to more than one node (e.g. the *ast.File
	// and the package name), and ranging over the map would visit those
	// nodes in random order. Visit the leftover groups in file order
	// (f.Comments) instead, so the directive order is deterministic.
	leftover := make(map[*ast.CommentGroup]bool)
	for _, cgs := range cmap {
		for _, cg := range cgs {
			leftover[cg] = true
		}
	}
	unassoc := &groups[0]
	for _, cg := range f.Comments {
		if leftover[cg] {
			unassoc.dirs = extractDirectives(unassoc.dirs, cg)
		}
	}
	if len(unassoc.dirs) == 0 {
		return groups[1:]
	}
	return groups
}
|
|
|
|
// extractDirectives appends to d the goose directives at the start of the
// comment group cg and returns the extended slice.
//
// Directives are lines of the form "goose:KIND [ARGS]". They must come first
// in the comment, and scanning stops at the first line that is not a
// directive. KIND ends at the first whitespace character of any kind (space,
// tab, ...), and ARGS is kept with surrounding whitespace trimmed. Every
// directive is given the position of cg.
//
// NOTE(review): since Go 1.15, CommentGroup.Text omits lines that look like
// tool directives, e.g. "//goose:provide" written without a space after
// "//". Confirm whether that spelling should be supported.
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
	const prefix = "goose:"
	text := cg.Text()
	for len(text) > 0 {
		text = strings.TrimLeft(text, " \t\r\n")
		if !strings.HasPrefix(text, prefix) {
			break
		}
		line := text[len(prefix):]
		// Text() is newline terminated, but don't depend on it: a missing
		// final newline simply ends the text.
		if i := strings.IndexByte(line, '\n'); i != -1 {
			line, text = line[:i], line[i+1:]
		} else {
			text = ""
		}
		// Split on any whitespace. Splitting on ' ' alone turned
		// "goose:provide\tSet" into the bogus kind "provide\tSet", which was
		// then rejected as an unknown directive. Slicing the rest from i
		// (not i+1) keeps multi-byte space runes intact for TrimSpace.
		kind, rest := line, ""
		if i := strings.IndexFunc(line, unicode.IsSpace); i != -1 {
			kind, rest = line[:i], strings.TrimSpace(line[i:])
		}
		d = append(d, directive{
			pos:  cg.Pos(), // TODO(light): more precise position
			kind: kind,
			line: rest,
		})
	}
	return d
}
|
|
|
|
// single returns the only directive of the given kind in dg. If there is
// none, it returns the zero directive, whose kind is empty. A second
// directive of that kind is an error. The error names the declaration when
// it is a function, or a type declaration with exactly one spec.
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
	var (
		first directive
		seen  bool
	)
	for _, d := range dg.dirs {
		if d.kind != kind {
			continue
		}
		if !seen {
			first, seen = d, true
			continue
		}
		// A second directive of the same kind: report it.
		where := fset.Position(d.pos)
		if fn, ok := dg.decl.(*ast.FuncDecl); ok {
			return directive{}, fmt.Errorf("%v: multiple %s directives for %s", where, kind, fn.Name.Name)
		}
		if gd, ok := dg.decl.(*ast.GenDecl); ok && gd.Tok == token.TYPE && len(gd.Specs) == 1 {
			name := gd.Specs[0].(*ast.TypeSpec).Name.Name
			return directive{}, fmt.Errorf("%v: multiple %s directives for %s", where, kind, name)
		}
		return directive{}, fmt.Errorf("%v: multiple %s directives", where, kind)
	}
	return first, nil
}
|
|
|
|
func (d directive) isValid() bool {
|
|
return d.kind != ""
|
|
}
|
|
|
|
// args splits the directive line into whitespace-separated tokens.
//
// Double-quoted sections may contain whitespace and backslash escapes and
// can be glued to other characters (e.g. "a/b".Set stays one token). The
// quotes and escapes are kept verbatim in the returned tokens, and an
// unterminated quote runs to the end of the line.
func (d directive) args() []string {
	const (
		atBoundary     = iota // between tokens
		inToken               // inside an unquoted part of a token
		inQuote               // inside a double-quoted section
		afterBackslash        // just read a backslash inside quotes
	)
	var tokens []string
	tokStart := -1
	st := atBoundary
	for pos, ch := range d.line {
		switch st {
		case atBoundary:
			if ch == '"' {
				tokStart, st = pos, inQuote
			} else if !unicode.IsSpace(ch) {
				tokStart, st = pos, inToken
			}
		case inToken:
			if unicode.IsSpace(ch) {
				tokens = append(tokens, d.line[tokStart:pos])
				tokStart, st = -1, atBoundary
			} else if ch == '"' {
				st = inQuote
			}
		case inQuote:
			if ch == '"' {
				st = inToken
			} else if ch == '\\' {
				st = afterBackslash
			}
		case afterBackslash:
			// The escaped character is consumed as-is.
			st = inQuote
		default:
			panic("unreachable")
		}
	}
	if tokStart != -1 {
		tokens = append(tokens, d.line[tokStart:])
	}
	return tokens
}
|
|
|
|
// isInjectFile reports whether a given file is an injection template. That
// is the case when one of its "+build" constraint lines names the
// gooseinject tag as a positive (non-negated) term.
//
// The previous substring check also matched "+build !gooseinject", which
// marks a file excluded from injection builds, and unrelated tags that
// merely contain the text (e.g. "gooseinjection"). Neither matches now.
func isInjectFile(f *ast.File) bool {
	// TODO(light): better determination
	for _, cg := range f.Comments {
		for _, line := range strings.Split(cg.Text(), "\n") {
			fields := strings.Fields(line)
			if len(fields) == 0 || fields[0] != "+build" {
				continue
			}
			// In a +build line, options are separated by spaces (OR) and
			// terms within an option by commas (AND). A leading '!' negates
			// a term, so "!gooseinject" never equals "gooseinject".
			for _, opt := range fields[1:] {
				for _, term := range strings.Split(opt, ",") {
					if term == "gooseinject" {
						return true
					}
				}
			}
		}
	}
	return false
}
|
|
|
|
// paramIndex returns the position of the parameter called name within
// params, or -1 when no parameter has that name. If several parameters share
// the name, the first one wins. A nil tuple has no parameters.
func paramIndex(params *types.Tuple, name string) int {
	for i, n := 0, params.Len(); i < n; i++ {
		if v := params.At(i); v.Name() == name {
			return i
		}
	}
	return -1
}
|