Files
wire/internal/goose/parse.go
Ross Light 1380f96c06 goose: add interface binding
An interface binding instructs goose that a concrete type should be used
to satisfy a dependency on an interface type. goose could determine this
implicitly, but having an explicit directive makes the provider author's
intent clear and allows different concrete types to satisfy different
smaller interfaces.

Reviewed-by: Tuo Shan <shantuo@google.com>
2018-11-12 14:09:56 -08:00

613 lines
17 KiB
Go

package goose
import (
"fmt"
"go/ast"
"go/build"
"go/token"
"go/types"
"path/filepath"
"strconv"
"strings"
"unicode"
"golang.org/x/tools/go/loader"
)
// A providerSet describes a set of providers. The zero value is an empty
// providerSet.
//
// findProviderSets builds these keyed by set name, accumulating entries from
// goose:provide, goose:bind, and goose:import directives.
type providerSet struct {
	// providers are the provider functions added by goose:provide.
	providers []*providerInfo
	// bindings are the interface bindings added by goose:bind.
	bindings []ifaceBinding
	// imports are references to other provider sets added by goose:import.
	imports []providerSetImport
}
// An ifaceBinding declares that a type should be used to satisfy inputs
// of the given interface type.
//
// provided is always a type that is assignable to iface.
type ifaceBinding struct {
	// iface is the interface type whose inputs are satisfied.
	iface types.Type
	// provided is the type bound to iface; it is a pointer type when the
	// goose:bind directive wrote the type with a leading '*'.
	provided types.Type
	// pos is the source position associated with the binding (token.NoPos
	// if the code that built the binding did not set it).
	pos token.Pos
}
// A providerSetImport is a reference, made by a goose:import directive, to
// another provider set.
type providerSetImport struct {
	symref
	// pos is the position of the goose:import directive.
	pos token.Pos
}
// providerInfo records the signature of a provider function.
type providerInfo struct {
	// importPath is the import path of the package declaring the function.
	importPath string
	// funcName is the function's name.
	funcName string
	pos      token.Pos // provider function definition
	// args describes the function's parameters in declaration order.
	args []providerInput
	// out is the type of the function's first result.
	out types.Type
	// hasErr reports whether the function has a second result of type error.
	hasErr bool
}
// A providerInput describes one parameter of a provider function.
type providerInput struct {
	// typ is the parameter's type.
	typ types.Type
	// optional is true when the parameter was named by a goose:optional directive.
	optional bool
}
// A findContext carries the per-package state used while extracting provider
// sets from a package's files.
type findContext struct {
	fset *token.FileSet
	// pkg is the package whose provider sets are being extracted.
	pkg *types.Package
	// typeInfo holds the type-checking results (scopes, objects) for pkg's files.
	typeInfo *types.Info
	// r resolves quoted import paths that appear in symbol references.
	r *importResolver
}
// findProviderSets extracts the provider sets declared in a package's files.
//
// The result maps provider set names to their contents. Directive groups tied
// to a top-level declaration go through processDeclDirectives; every other
// directive is handled individually by processUnassociatedDirective.
// Processing stops at the first error.
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) {
	sets := make(map[string]*providerSet)
	for _, file := range files {
		scope := fctx.typeInfo.Scopes[file]
		if scope == nil {
			fileName := fctx.fset.File(file.Pos()).Name()
			return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fileName)
		}
		for _, grp := range parseFile(fctx.fset, file) {
			if grp.decl != nil {
				if err := processDeclDirectives(fctx, sets, scope, grp); err != nil {
					return nil, err
				}
				continue
			}
			for _, dir := range grp.dirs {
				if err := processUnassociatedDirective(fctx, sets, scope, dir); err != nil {
					return nil, err
				}
			}
		}
	}
	return sets, nil
}
// processUnassociatedDirective handles any directive that was not associated
// with a top-level declaration, recording its effect in sets.
//
// goose:provide and goose:optional are rejected here because they must be
// attached to a function declaration; goose:use is ignored. goose:bind and
// goose:import are delegated to processBindDirective and
// processImportDirective. Any other directive kind is an error.
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
	switch d.kind {
	case "provide", "optional":
		return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
	case "use":
		// Ignore, picked up by injector flow.
		return nil
	case "bind":
		return processBindDirective(fctx, sets, scope, d)
	case "import":
		return processImportDirective(fctx, sets, scope, d)
	default:
		return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind)
	}
}

// processBindDirective handles "goose:bind TARGET IFACE TYPE". It records in
// provider set TARGET (created if absent) that TYPE satisfies the interface
// IFACE. If TYPE is written with a leading '*', *TYPE is bound instead.
//
// IFACE must name an interface type, TYPE must not name IFACE itself, and the
// (possibly pointer) provided type must implement IFACE.
func processBindDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
	pos := fctx.fset.Position(d.pos)
	args := d.args()
	if len(args) != 3 {
		return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", pos)
	}
	iface, ifaceRef, err := resolveTypeName(fctx, scope, d.pos, args[1])
	if err != nil {
		return err
	}
	methodSet, ok := iface.Underlying().(*types.Interface)
	if !ok {
		return fmt.Errorf("%v: %v does not name an interface type", pos, ifaceRef)
	}
	provided, _, err := resolveTypeName(fctx, scope, d.pos, strings.TrimPrefix(args[2], "*"))
	if err != nil {
		return err
	}
	if types.Identical(provided, iface) {
		return fmt.Errorf("%v: cannot bind interface to itself", pos)
	}
	if strings.HasPrefix(args[2], "*") {
		provided = types.NewPointer(provided)
	}
	if !types.Implements(provided, methodSet) {
		return fmt.Errorf("%v: %s does not implement %s", pos, types.TypeString(provided, nil), types.TypeString(iface, nil))
	}
	name := args[0]
	pset := sets[name]
	if pset == nil {
		pset = new(providerSet)
		sets[name] = pset
	}
	pset.bindings = append(pset.bindings, ifaceBinding{
		iface:    iface,
		provided: provided,
		// Record the directive's position, as goose:import does for its entries.
		pos: d.pos,
	})
	return nil
}

// processImportDirective handles "goose:import TARGET SETREF...", adding
// each referenced provider set to the imports of provider set TARGET
// (created if absent). A reference already present is not added again. Each
// reference must be to the package itself or to one of its direct imports.
func processImportDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
	pos := fctx.fset.Position(d.pos)
	args := d.args()
	if len(args) < 2 {
		return fmt.Errorf("%v: invalid import: expected TARGET SETREF", pos)
	}
	name := args[0]
	for _, spec := range args[1:] {
		ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
		if err != nil {
			return fmt.Errorf("%v: %v", pos, err)
		}
		if findImport(fctx.pkg, ref.importPath) == nil {
			return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", pos, name, ref.importPath)
		}
		pset := sets[name]
		if pset == nil {
			pset = new(providerSet)
			sets[name] = pset
		}
		dup := false
		for _, other := range pset.imports {
			if other.symref == ref {
				dup = true
				break
			}
		}
		if !dup {
			pset.imports = append(pset.imports, providerSetImport{symref: ref, pos: d.pos})
		}
	}
	return nil
}

// resolveTypeName parses spec as a symbol reference in scope and returns the
// type it names, together with the parsed reference. Errors are prefixed with
// the file position of pos.
func resolveTypeName(fctx findContext, scope *types.Scope, pos token.Pos, spec string) (types.Type, symref, error) {
	ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), pos)
	if err != nil {
		return nil, symref{}, fmt.Errorf("%v: %v", fctx.fset.Position(pos), err)
	}
	obj, err := ref.resolveObject(fctx.pkg)
	if err != nil {
		return nil, symref{}, fmt.Errorf("%v: %v", fctx.fset.Position(pos), err)
	}
	tn, ok := obj.(*types.TypeName)
	if !ok {
		return nil, symref{}, fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(pos), ref)
	}
	return tn.Type(), ref, nil
}
// processDeclDirectives processes the directives associated with a top-level
// declaration.
//
// Without a goose:provide directive, the group is ignored, except that a
// goose:optional directive is then an error. With one, the declaration must
// be a plain function (not a method) whose results are either (T) or
// (T, error). The function is recorded as a provider of T. It goes in the
// provider set named by the directive's single argument, or, when there is no
// argument, a set named after the function. goose:optional directives mark
// the named parameters as optional inputs. A provider may not have two
// parameters of identical type, and a set may not hold two providers of
// identical output type.
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error {
	p, err := dg.single(fctx.fset, "provide")
	if err != nil {
		return err
	}
	if !p.isValid() {
		for _, d := range dg.dirs {
			if d.kind == "optional" {
				return fmt.Errorf("%v: cannot use goose:%s directive on non-provider", fctx.fset.Position(d.pos), d.kind)
			}
		}
		return nil
	}
	fn, ok := dg.decl.(*ast.FuncDecl)
	if !ok {
		return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos))
	}
	if fn.Recv != nil {
		// A providerInfo identifies its function only by import path and
		// name, which cannot describe a call to a method.
		return fmt.Errorf("%v: methods cannot be marked as providers", fctx.fset.Position(p.pos))
	}
	fpos := fn.Pos()
	sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
	params := sig.Params()

	// Mark the parameters named by goose:optional directives as optional inputs.
	optionals := make([]bool, params.Len())
	for _, d := range dg.dirs {
		if d.kind != "optional" {
			continue
		}
		for _, arg := range d.args() {
			pi := paramIndex(params, arg)
			if pi == -1 {
				return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
			}
			optionals[pi] = true
		}
	}

	// The results must be (T) or (T, error).
	results := sig.Results()
	hasErr := false
	switch results.Len() {
	case 1:
		// Only the provided value.
	case 2:
		if t := results.At(1).Type(); !types.Identical(t, errorType) {
			return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
		}
		hasErr = true
	default:
		return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
	}

	provider := &providerInfo{
		importPath: fctx.pkg.Path(),
		funcName:   fn.Name.Name,
		pos:        fpos,
		args:       make([]providerInput, params.Len()),
		out:        results.At(0).Type(),
		hasErr:     hasErr,
	}
	for i := range provider.args {
		provider.args[i] = providerInput{
			typ:      params.At(i).Type(),
			optional: optionals[i],
		}
		// Inputs are matched by type, so two parameters of one type would be ambiguous.
		for j := 0; j < i; j++ {
			if types.Identical(provider.args[i].typ, provider.args[j].typ) {
				return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
			}
		}
	}

	providerSetName := fn.Name.Name
	if args := p.args(); len(args) == 1 {
		// TODO(light): validate identifier
		providerSetName = args[0]
	} else if len(args) > 1 {
		return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos))
	}
	mod := sets[providerSetName]
	if mod == nil {
		sets[providerSetName] = &providerSet{
			providers: []*providerInfo{provider},
		}
		return nil
	}
	for _, other := range mod.providers {
		if types.Identical(other.out, provider.out) {
			return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
		}
	}
	mod.providers = append(mod.providers, provider)
	return nil
}
// providerSetCache is a lazily evaluated index of provider sets.
//
// A package's provider sets are extracted the first time get is asked for a
// set in that package.
type providerSetCache struct {
	// sets maps an import path to that package's provider sets keyed by set
	// name. It is allocated by get on first use.
	sets map[string]map[string]*providerSet
	fset *token.FileSet
	// prog supplies the parsed and type-checked packages.
	prog *loader.Program
	// r is passed to findProviderSets to resolve quoted import paths.
	r *importResolver
}
// newProviderSetCache returns an empty cache that draws packages from prog
// and resolves quoted import paths with r.
func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCache {
	mc := new(providerSetCache)
	mc.fset = prog.Fset
	mc.prog = prog
	mc.r = r
	return mc
}
// get returns the provider set named by ref. On the first request for a
// package it extracts and caches all of that package's provider sets.
//
// Only successful extractions are cached. If a package's provider sets
// cannot be computed, the error is returned and a later call retries. The
// later call then reports the same error rather than a misleading
// "no such provider set".
func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
	mods, cached := mc.sets[ref.importPath]
	if !cached {
		pkg := mc.prog.Package(ref.importPath)
		if pkg == nil {
			return nil, fmt.Errorf("no such provider set %s in package %q: package not loaded", ref.name, ref.importPath)
		}
		var err error
		mods, err = findProviderSets(findContext{
			fset:     mc.fset,
			pkg:      pkg.Pkg,
			typeInfo: &pkg.Info,
			r:        mc.r,
		}, pkg.Files)
		if err != nil {
			return nil, err
		}
		if mc.sets == nil {
			mc.sets = make(map[string]map[string]*providerSet)
		}
		mc.sets[ref.importPath] = mods
	}
	mod := mods[ref.name]
	if mod == nil {
		return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
	}
	return mod, nil
}
// A symref is a parsed reference to a symbol (either a provider set or a Go object).
type symref struct {
	// importPath is the import path of the package that declares the symbol.
	importPath string
	// name is the symbol's name within that package.
	name string
}
// parseSymbolRef parses ref as a reference to a symbol, as seen from the file
// scope s at position pos. The text after the last '.' is the symbol name.
// ref takes one of three forms:
//
//   - NAME refers to NAME in the package whose import path is pkg.
//   - "IMPORT/PATH".NAME refers to NAME in the quoted package, whose path is
//     resolved by r relative to the directory of the file containing pos.
//   - PKGNAME.NAME refers to NAME in the package imported under PKGNAME,
//     which is looked up in s at pos.
func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) {
	// TODO(light): verify that provider set name is an identifier before returning
	dot := strings.LastIndexByte(ref, '.')
	if dot < 0 {
		return symref{importPath: pkg, name: ref}, nil
	}
	qualifier, name := ref[:dot], ref[dot+1:]
	if !strings.HasPrefix(qualifier, `"`) {
		// Qualified by a package name imported into this file.
		_, obj := s.LookupParent(qualifier, pos)
		if obj == nil {
			return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, qualifier)
		}
		pkgName, isPkg := obj.(*types.PkgName)
		if !isPkg {
			return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, qualifier)
		}
		return symref{importPath: pkgName.Imported().Path(), name: name}, nil
	}
	// Qualified by a quoted import path.
	unquoted, err := strconv.Unquote(qualifier)
	if err != nil {
		return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref)
	}
	resolved, err := r.resolve(pos, unquoted)
	if err != nil {
		return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err)
	}
	return symref{importPath: resolved, name: name}, nil
}
// String formats ref as its quoted import path, a dot, and its name, for
// example "example.com/foo".Bar.
func (ref symref) String() string {
	return fmt.Sprintf("%s.%s", strconv.Quote(ref.importPath), ref.name)
}
// resolveObject returns the package-level object that ref names. ref's
// package must be pkg itself or one of pkg's direct imports.
func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) {
	target := findImport(pkg, ref.importPath)
	if target == nil {
		return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref)
	}
	if obj := target.Scope().Lookup(ref.name); obj != nil {
		return obj, nil
	}
	return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name)
}
// An importResolver maps an import path, as written in a particular source
// file, to the import path of the package it refers to.
type importResolver struct {
	// fset locates the file (and thus directory) that a position belongs to.
	fset *token.FileSet
	bctx *build.Context
	// findPackage locates packages; newImportResolver defaults it to
	// (*build.Context).Import.
	findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
}
// newImportResolver returns an importResolver using c's build context and
// package finder. It falls back to build.Default and (*build.Context).Import
// for whichever of those c leaves unset.
func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver {
	bctx := c.Build
	if bctx == nil {
		bctx = &build.Default
	}
	find := c.FindPackage
	if find == nil {
		find = (*build.Context).Import
	}
	return &importResolver{
		fset:        fset,
		bctx:        bctx,
		findPackage: find,
	}
}
// resolve returns the import path of the package that path refers to when
// imported from the file containing pos. The lookup uses build.FindOnly
// mode, relative to that file's directory.
//
// It returns an error if pos is not within a file known to r.fset, instead
// of dereferencing a nil file, or if the package cannot be found.
func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
	file := r.fset.File(pos)
	if file == nil {
		return "", fmt.Errorf("resolve import %q: position is not in a known file", path)
	}
	pkg, err := r.findPackage(r.bctx, path, filepath.Dir(file.Name()), build.FindOnly)
	if err != nil {
		return "", err
	}
	return pkg.ImportPath, nil
}
// findImport returns pkg itself when its import path is path. Otherwise it
// returns pkg's direct import with that path, or nil if there is none.
func findImport(pkg *types.Package, path string) *types.Package {
	if pkg.Path() == path {
		return pkg
	}
	imports := pkg.Imports()
	for i := range imports {
		if imports[i].Path() == path {
			return imports[i]
		}
	}
	return nil
}
// A directive is a parsed goose comment.
type directive struct {
	// pos is the position of the comment group containing the directive.
	pos token.Pos
	// kind is the word after "goose:", e.g. "provide" for goose:provide.
	kind string
	// line is the remainder of the directive's line after kind, with
	// surrounding whitespace trimmed; args splits it into tokens.
	line string
}
// A directiveGroup is a set of directives associated with a particular
// declaration.
type directiveGroup struct {
	// decl is the associated top-level declaration, or nil for the group of
	// directives that are not associated with any declaration.
	decl ast.Decl
	// dirs are the group's directives.
	dirs []directive
}
// parseFile extracts the directives from a file, grouped by declaration.
//
// Each top-level declaration gets a group, with decl set to it, holding the
// goose:provide, goose:optional, and goose:use directives found in comments
// associated with the declaration or any node inside it. Declarations without
// such directives get no group. All other directives are collected into one
// group with a nil decl, which is the first element of the result when it is
// non-empty. That group lists directives moved out of declarations first,
// followed by those from the remaining comments in f.Comments order.
func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
	cmap := ast.NewCommentMap(fset, f, f.Comments)
	// Reserve first group for directives that don't associate with a
	// declaration, like import.
	groups := make([]directiveGroup, 1, len(f.Decls)+1)
	// Walk declarations and add to groups.
	for _, decl := range f.Decls {
		grp := directiveGroup{decl: decl}
		ast.Inspect(decl, func(node ast.Node) bool {
			if g := cmap[node]; len(g) > 0 {
				for _, cg := range g {
					start := len(grp.dirs)
					grp.dirs = extractDirectives(grp.dirs, cg)
					// Move directives that don't associate into the unassociated group.
					n := 0
					for i := start; i < len(grp.dirs); i++ {
						if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" {
							grp.dirs[start+n] = grp.dirs[i]
							n++
						} else {
							groups[0].dirs = append(groups[0].dirs, grp.dirs[i])
						}
					}
					grp.dirs = grp.dirs[:start+n]
				}
				// Consumed; the leftover pass below must not see these comments again.
				delete(cmap, node)
			}
			return true
		})
		if len(grp.dirs) > 0 {
			groups = append(groups, grp)
		}
	}
	// Place remaining directives into the unassociated group. Ranging over
	// cmap would visit comment groups in Go's randomized map order, so the
	// order of these directives (and of the bindings, imports, and errors
	// derived from them) would vary between runs. Visit the leftover comment
	// groups in f.Comments order instead.
	leftover := make(map[*ast.CommentGroup]bool)
	for _, g := range cmap {
		for _, cg := range g {
			leftover[cg] = true
		}
	}
	unassoc := &groups[0]
	for _, cg := range f.Comments {
		if leftover[cg] {
			unassoc.dirs = extractDirectives(unassoc.dirs, cg)
		}
	}
	if len(unassoc.dirs) == 0 {
		return groups[1:]
	}
	return groups
}
// extractDirectives appends to d the goose directives at the start of the
// comment group cg and returns the extended slice.
//
// It works on cg.Text(), i.e. with comment markers already removed.
// Whitespace-only lines are skipped, and leading spaces, tabs, and carriage
// returns on a line are ignored. Scanning stops at the first remaining line
// that does not begin with "goose:". For a directive line "goose:KIND" or
// "goose:KIND REST", KIND ends at the first space and REST is trimmed of
// surrounding whitespace. Every directive is given cg's position.
//
// NOTE(review): since Go 1.15, CommentGroup.Text omits lines that look like
// tool directives (e.g. "//goose:provide" with no space after the slashes),
// so such comments may never reach this function. Confirm the expected
// directive spelling.
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
	const prefix = "goose:"
	for _, raw := range strings.Split(cg.Text(), "\n") {
		line := strings.TrimLeft(raw, " \t\r")
		if line == "" {
			continue
		}
		if !strings.HasPrefix(line, prefix) {
			break
		}
		body := line[len(prefix):]
		dir := directive{
			kind: body,
			pos:  cg.Pos(), // TODO(light): more precise position
		}
		if sp := strings.IndexByte(body, ' '); sp != -1 {
			dir.kind = body[:sp]
			dir.line = strings.TrimSpace(body[sp+1:])
		}
		d = append(d, dir)
	}
	return d
}
// single returns the directive of the given kind in dg, or the zero
// directive if dg has none. More than one such directive is an error. The
// error names the declaration when dg.decl is a function or a type
// declaration with a single spec, and points at the second occurrence.
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
	var found directive
	seen := false
	for _, d := range dg.dirs {
		if d.kind != kind {
			continue
		}
		if !seen {
			found, seen = d, true
			continue
		}
		var declName string
		named := false
		switch decl := dg.decl.(type) {
		case *ast.FuncDecl:
			declName, named = decl.Name.Name, true
		case *ast.GenDecl:
			if decl.Tok == token.TYPE && len(decl.Specs) == 1 {
				declName, named = decl.Specs[0].(*ast.TypeSpec).Name.Name, true
			}
		}
		if !named {
			return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
		}
		return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, declName)
	}
	return found, nil
}
// isValid reports whether d is a real directive rather than the zero value
// (for example, the "not found" result of single).
func (d directive) isValid() bool {
	return len(d.kind) != 0
}
// args splits the directive line into whitespace-separated tokens.
//
// A double-quoted section, in which a backslash escapes the next character,
// stays within one token even if it contains whitespace. Quotes and
// backslashes are kept in the token text. A quote left open at the end of
// the line extends the final token to the end of the line.
func (d directive) args() []string {
	const (
		atBoundary     = iota // between tokens
		inToken               // in an unquoted part of a token
		inQuote               // inside a double-quoted section
		afterBackslash        // just after a backslash inside quotes
	)
	var args []string
	start := -1
	state := atBoundary
	for i, r := range d.line {
		switch state {
		case atBoundary:
			if r == '"' {
				start, state = i, inQuote
			} else if !unicode.IsSpace(r) {
				start, state = i, inToken
			}
		case inToken:
			if unicode.IsSpace(r) {
				args = append(args, d.line[start:i])
				start, state = -1, atBoundary
			} else if r == '"' {
				state = inQuote
			}
		case inQuote:
			if r == '"' {
				state = inToken
			} else if r == '\\' {
				state = afterBackslash
			}
		case afterBackslash:
			// The escaped character is consumed; resume the quoted section.
			state = inQuote
		default:
			panic("unreachable")
		}
	}
	if start != -1 {
		args = append(args, d.line[start:])
	}
	return args
}
// isInjectFile reports whether a given file is an injection template. A file
// qualifies when one of its comments contains a "+build" line that enables
// the gooseinject tag.
//
// Each option on the build line is split into its comma-separated terms, and
// the file matches only if some term is exactly "gooseinject". Negated terms
// such as "!gooseinject", and tags that merely contain the word, such as
// "gooseinjectx", do not match.
func isInjectFile(f *ast.File) bool {
	// TODO(light): better determination
	for _, cg := range f.Comments {
		for _, line := range strings.Split(cg.Text(), "\n") {
			fields := strings.Fields(line)
			if len(fields) == 0 || fields[0] != "+build" {
				continue
			}
			for _, opt := range fields[1:] {
				for _, term := range strings.Split(opt, ",") {
					if term == "gooseinject" {
						return true
					}
				}
			}
		}
	}
	return false
}
// paramIndex reports the position of the first parameter in params whose
// name is name, or -1 when no parameter has that name.
func paramIndex(params *types.Tuple, name string) int {
	for i, n := 0, params.Len(); i < n; i++ {
		if v := params.At(i); v.Name() == name {
			return i
		}
	}
	return -1
}