goose: split into multiple files
No functional change, just to make future diffs easier. Reviewed-by: Tuo Shan <shantuo@google.com> Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
152
internal/goose/analyze.go
Normal file
152
internal/goose/analyze.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package goose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"go/token"
|
||||||
|
"go/types"
|
||||||
|
|
||||||
|
"golang.org/x/tools/go/types/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A call represents a step of an injector function.
|
||||||
|
type call struct {
|
||||||
|
// importPath and funcName identify the provider function to call.
|
||||||
|
importPath string
|
||||||
|
funcName string
|
||||||
|
|
||||||
|
// args is a list of arguments to call the provider with. Each element is either:
|
||||||
|
// a) one of the givens (args[i] < len(given)) or
|
||||||
|
// b) the result of a previous provider call (args[i] >= len(given)).
|
||||||
|
args []int
|
||||||
|
|
||||||
|
// out is the type produced by this provider call.
|
||||||
|
out types.Type
|
||||||
|
|
||||||
|
// hasErr is true if the provider call returns an error.
|
||||||
|
hasErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// solve finds the sequence of calls required to produce an output type
|
||||||
|
// with an optional set of provided inputs.
|
||||||
|
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) {
|
||||||
|
for i, g := range given {
|
||||||
|
for _, h := range given[:i] {
|
||||||
|
if types.Identical(g, h) {
|
||||||
|
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
providers, err := buildProviderMap(mc, sets)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start building the mapping of type to local variable of the given type.
|
||||||
|
// The first len(given) local variables are the given types.
|
||||||
|
index := new(typeutil.Map)
|
||||||
|
for i, g := range given {
|
||||||
|
if p := providers.At(g); p != nil {
|
||||||
|
pp := p.(*providerInfo)
|
||||||
|
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
|
||||||
|
}
|
||||||
|
index.Set(g, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topological sort of the directed graph defined by the providers
|
||||||
|
// using a depth-first search. The graph may contain cycles, which
|
||||||
|
// should trigger an error.
|
||||||
|
var calls []call
|
||||||
|
var visit func(trail []types.Type) error
|
||||||
|
visit = func(trail []types.Type) error {
|
||||||
|
typ := trail[len(trail)-1]
|
||||||
|
if index.At(typ) != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, t := range trail[:len(trail)-1] {
|
||||||
|
if types.Identical(typ, t) {
|
||||||
|
// TODO(light): describe cycle
|
||||||
|
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p, _ := providers.At(typ).(*providerInfo)
|
||||||
|
if p == nil {
|
||||||
|
if len(trail) == 1 {
|
||||||
|
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
||||||
|
}
|
||||||
|
// TODO(light): give name of provider
|
||||||
|
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil))
|
||||||
|
}
|
||||||
|
for _, a := range p.args {
|
||||||
|
// TODO(light): this will discard grown trail arrays.
|
||||||
|
if err := visit(append(trail, a)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args := make([]int, len(p.args))
|
||||||
|
for i := range p.args {
|
||||||
|
args[i] = index.At(p.args[i]).(int)
|
||||||
|
}
|
||||||
|
index.Set(typ, len(given)+len(calls))
|
||||||
|
calls = append(calls, call{
|
||||||
|
importPath: p.importPath,
|
||||||
|
funcName: p.funcName,
|
||||||
|
args: args,
|
||||||
|
out: typ,
|
||||||
|
hasErr: p.hasErr,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := visit([]types.Type{out}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Map, error) {
|
||||||
|
type nextEnt struct {
|
||||||
|
to providerSetRef
|
||||||
|
|
||||||
|
from providerSetRef
|
||||||
|
pos token.Pos
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := new(typeutil.Map) // to *providerInfo
|
||||||
|
visited := make(map[providerSetRef]struct{})
|
||||||
|
var next []nextEnt
|
||||||
|
for _, ref := range sets {
|
||||||
|
next = append(next, nextEnt{to: ref})
|
||||||
|
}
|
||||||
|
for len(next) > 0 {
|
||||||
|
curr := next[0]
|
||||||
|
copy(next, next[1:])
|
||||||
|
next = next[:len(next)-1]
|
||||||
|
if _, skip := visited[curr.to]; skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visited[curr.to] = struct{}{}
|
||||||
|
mod, err := mc.get(curr.to)
|
||||||
|
if err != nil {
|
||||||
|
if !curr.pos.IsValid() {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
|
||||||
|
}
|
||||||
|
for _, p := range mod.providers {
|
||||||
|
if prev := pm.At(p.out); prev != nil {
|
||||||
|
pos := mc.fset.Position(p.pos)
|
||||||
|
typ := types.TypeString(p.out, nil)
|
||||||
|
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
||||||
|
if curr.from.importPath != "" {
|
||||||
|
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
|
||||||
|
}
|
||||||
|
pm.Set(p.out, p)
|
||||||
|
}
|
||||||
|
for _, imp := range mod.imports {
|
||||||
|
next = append(next, nextEnt{to: imp.providerSetRef, from: curr.to, pos: imp.pos})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pm, nil
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"go/build"
|
"go/build"
|
||||||
"go/format"
|
"go/format"
|
||||||
"go/parser"
|
"go/parser"
|
||||||
"go/token"
|
|
||||||
"go/types"
|
"go/types"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -18,7 +17,6 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"golang.org/x/tools/go/loader"
|
"golang.org/x/tools/go/loader"
|
||||||
"golang.org/x/tools/go/types/typeutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Generate performs dependency injection for a single package,
|
// Generate performs dependency injection for a single package,
|
||||||
@@ -310,430 +308,6 @@ func (g *gen) p(format string, args ...interface{}) {
|
|||||||
fmt.Fprintf(&g.buf, format, args...)
|
fmt.Fprintf(&g.buf, format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// A providerSet describes a set of providers. The zero value is an empty
|
|
||||||
// providerSet.
|
|
||||||
type providerSet struct {
|
|
||||||
providers []*providerInfo
|
|
||||||
imports []providerSetImport
|
|
||||||
}
|
|
||||||
|
|
||||||
type providerSetImport struct {
|
|
||||||
providerSetRef
|
|
||||||
pos token.Pos
|
|
||||||
}
|
|
||||||
|
|
||||||
// findProviderSets processes a package and extracts the provider sets declared in it.
|
|
||||||
func findProviderSets(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
|
|
||||||
sets := make(map[string]*providerSet)
|
|
||||||
var directives []directive
|
|
||||||
for _, f := range files {
|
|
||||||
fileScope := typeInfo.Scopes[f]
|
|
||||||
for _, c := range f.Comments {
|
|
||||||
directives = extractDirectives(directives[:0], c)
|
|
||||||
for _, d := range directives {
|
|
||||||
switch d.kind {
|
|
||||||
case "provide", "use":
|
|
||||||
// handled later
|
|
||||||
case "import":
|
|
||||||
if fileScope == nil {
|
|
||||||
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
|
|
||||||
}
|
|
||||||
i := strings.IndexByte(d.line, ' ')
|
|
||||||
// TODO(light): allow multiple imports in one line
|
|
||||||
if i == -1 {
|
|
||||||
return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos))
|
|
||||||
}
|
|
||||||
name, spec := d.line[:i], d.line[i+1:]
|
|
||||||
ref, err := parseProviderSetRef(spec, fileScope, pkg.Path(), d.pos)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
|
|
||||||
}
|
|
||||||
if ref.importPath != pkg.Path() {
|
|
||||||
imported := false
|
|
||||||
for _, imp := range pkg.Imports() {
|
|
||||||
if ref.importPath == imp.Path() {
|
|
||||||
imported = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !imported {
|
|
||||||
return nil, fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if mod := sets[name]; mod != nil {
|
|
||||||
found := false
|
|
||||||
for _, other := range mod.imports {
|
|
||||||
if ref == other.providerSetRef {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sets[name] = &providerSet{
|
|
||||||
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cmap := ast.NewCommentMap(fset, f, f.Comments)
|
|
||||||
for _, decl := range f.Decls {
|
|
||||||
directives = directives[:0]
|
|
||||||
for _, cg := range cmap[decl] {
|
|
||||||
directives = extractDirectives(directives, cg)
|
|
||||||
}
|
|
||||||
fn, isFunction := decl.(*ast.FuncDecl)
|
|
||||||
var providerSetName string
|
|
||||||
for _, d := range directives {
|
|
||||||
if d.kind != "provide" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if providerSetName != "" {
|
|
||||||
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
|
|
||||||
}
|
|
||||||
if !isFunction {
|
|
||||||
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
|
|
||||||
}
|
|
||||||
providerSetName = fn.Name.Name
|
|
||||||
if d.line != "" {
|
|
||||||
// TODO(light): validate identifier
|
|
||||||
providerSetName = d.line
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if providerSetName == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fpos := fn.Pos()
|
|
||||||
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
|
||||||
r := sig.Results()
|
|
||||||
var hasErr bool
|
|
||||||
switch r.Len() {
|
|
||||||
case 1:
|
|
||||||
hasErr = false
|
|
||||||
case 2:
|
|
||||||
if t := r.At(1).Type(); !types.Identical(t, errorType) {
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
|
|
||||||
}
|
|
||||||
hasErr = true
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
|
|
||||||
}
|
|
||||||
out := r.At(0).Type()
|
|
||||||
p := sig.Params()
|
|
||||||
provider := &providerInfo{
|
|
||||||
importPath: pkg.Path(),
|
|
||||||
funcName: fn.Name.Name,
|
|
||||||
pos: fn.Pos(),
|
|
||||||
args: make([]types.Type, p.Len()),
|
|
||||||
out: out,
|
|
||||||
hasErr: hasErr,
|
|
||||||
}
|
|
||||||
for i := 0; i < p.Len(); i++ {
|
|
||||||
provider.args[i] = p.At(i).Type()
|
|
||||||
for j := 0; j < i; j++ {
|
|
||||||
if types.Identical(provider.args[i], provider.args[j]) {
|
|
||||||
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if mod := sets[providerSetName]; mod != nil {
|
|
||||||
for _, other := range mod.providers {
|
|
||||||
if types.Identical(other.out, provider.out) {
|
|
||||||
return nil, fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fset.Position(other.pos))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mod.providers = append(mod.providers, provider)
|
|
||||||
} else {
|
|
||||||
sets[providerSetName] = &providerSet{
|
|
||||||
providers: []*providerInfo{provider},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sets, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// providerSetCache is a lazily evaluated index of provider sets.
|
|
||||||
type providerSetCache struct {
|
|
||||||
sets map[string]map[string]*providerSet
|
|
||||||
fset *token.FileSet
|
|
||||||
prog *loader.Program
|
|
||||||
}
|
|
||||||
|
|
||||||
func newProviderSetCache(prog *loader.Program) *providerSetCache {
|
|
||||||
return &providerSetCache{
|
|
||||||
fset: prog.Fset,
|
|
||||||
prog: prog,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
|
||||||
if mods, cached := mc.sets[ref.importPath]; cached {
|
|
||||||
mod := mods[ref.name]
|
|
||||||
if mod == nil {
|
|
||||||
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
|
|
||||||
}
|
|
||||||
return mod, nil
|
|
||||||
}
|
|
||||||
if mc.sets == nil {
|
|
||||||
mc.sets = make(map[string]map[string]*providerSet)
|
|
||||||
}
|
|
||||||
pkg := mc.prog.Package(ref.importPath)
|
|
||||||
mods, err := findProviderSets(mc.fset, pkg.Pkg, &pkg.Info, pkg.Files)
|
|
||||||
if err != nil {
|
|
||||||
mc.sets[ref.importPath] = nil
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mc.sets[ref.importPath] = mods
|
|
||||||
mod := mods[ref.name]
|
|
||||||
if mod == nil {
|
|
||||||
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
|
|
||||||
}
|
|
||||||
return mod, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// solve finds the sequence of calls required to produce an output type
|
|
||||||
// with an optional set of provided inputs.
|
|
||||||
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) {
|
|
||||||
for i, g := range given {
|
|
||||||
for _, h := range given[:i] {
|
|
||||||
if types.Identical(g, h) {
|
|
||||||
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
providers, err := buildProviderMap(mc, sets)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start building the mapping of type to local variable of the given type.
|
|
||||||
// The first len(given) local variables are the given types.
|
|
||||||
index := new(typeutil.Map)
|
|
||||||
for i, g := range given {
|
|
||||||
if p := providers.At(g); p != nil {
|
|
||||||
pp := p.(*providerInfo)
|
|
||||||
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
|
|
||||||
}
|
|
||||||
index.Set(g, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Topological sort of the directed graph defined by the providers
|
|
||||||
// using a depth-first search. The graph may contain cycles, which
|
|
||||||
// should trigger an error.
|
|
||||||
var calls []call
|
|
||||||
var visit func(trail []types.Type) error
|
|
||||||
visit = func(trail []types.Type) error {
|
|
||||||
typ := trail[len(trail)-1]
|
|
||||||
if index.At(typ) != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
for _, t := range trail[:len(trail)-1] {
|
|
||||||
if types.Identical(typ, t) {
|
|
||||||
// TODO(light): describe cycle
|
|
||||||
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p, _ := providers.At(typ).(*providerInfo)
|
|
||||||
if p == nil {
|
|
||||||
if len(trail) == 1 {
|
|
||||||
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
|
||||||
}
|
|
||||||
// TODO(light): give name of provider
|
|
||||||
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil))
|
|
||||||
}
|
|
||||||
for _, a := range p.args {
|
|
||||||
// TODO(light): this will discard grown trail arrays.
|
|
||||||
if err := visit(append(trail, a)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
args := make([]int, len(p.args))
|
|
||||||
for i := range p.args {
|
|
||||||
args[i] = index.At(p.args[i]).(int)
|
|
||||||
}
|
|
||||||
index.Set(typ, len(given)+len(calls))
|
|
||||||
calls = append(calls, call{
|
|
||||||
importPath: p.importPath,
|
|
||||||
funcName: p.funcName,
|
|
||||||
args: args,
|
|
||||||
out: typ,
|
|
||||||
hasErr: p.hasErr,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := visit([]types.Type{out}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return calls, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Map, error) {
|
|
||||||
type nextEnt struct {
|
|
||||||
to providerSetRef
|
|
||||||
|
|
||||||
from providerSetRef
|
|
||||||
pos token.Pos
|
|
||||||
}
|
|
||||||
|
|
||||||
pm := new(typeutil.Map) // to *providerInfo
|
|
||||||
visited := make(map[providerSetRef]struct{})
|
|
||||||
var next []nextEnt
|
|
||||||
for _, ref := range sets {
|
|
||||||
next = append(next, nextEnt{to: ref})
|
|
||||||
}
|
|
||||||
for len(next) > 0 {
|
|
||||||
curr := next[0]
|
|
||||||
copy(next, next[1:])
|
|
||||||
next = next[:len(next)-1]
|
|
||||||
if _, skip := visited[curr.to]; skip {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
visited[curr.to] = struct{}{}
|
|
||||||
mod, err := mc.get(curr.to)
|
|
||||||
if err != nil {
|
|
||||||
if !curr.pos.IsValid() {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
|
|
||||||
}
|
|
||||||
for _, p := range mod.providers {
|
|
||||||
if prev := pm.At(p.out); prev != nil {
|
|
||||||
pos := mc.fset.Position(p.pos)
|
|
||||||
typ := types.TypeString(p.out, nil)
|
|
||||||
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
|
||||||
if curr.from.importPath != "" {
|
|
||||||
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
|
|
||||||
}
|
|
||||||
pm.Set(p.out, p)
|
|
||||||
}
|
|
||||||
for _, imp := range mod.imports {
|
|
||||||
next = append(next, nextEnt{to: imp.providerSetRef, from: curr.to, pos: imp.pos})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return pm, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// A call represents a step of an injector function.
|
|
||||||
type call struct {
|
|
||||||
// importPath and funcName identify the provider function to call.
|
|
||||||
importPath string
|
|
||||||
funcName string
|
|
||||||
|
|
||||||
// args is a list of arguments to call the provider with. Each element is either:
|
|
||||||
// a) one of the givens (args[i] < len(given)) or
|
|
||||||
// b) the result of a previous provider call (args[i] >= len(given)).
|
|
||||||
args []int
|
|
||||||
|
|
||||||
// out is the type produced by this provider call.
|
|
||||||
out types.Type
|
|
||||||
|
|
||||||
// hasErr is true if the provider call returns an error.
|
|
||||||
hasErr bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// providerInfo records the signature of a provider function.
|
|
||||||
type providerInfo struct {
|
|
||||||
importPath string
|
|
||||||
funcName string
|
|
||||||
pos token.Pos
|
|
||||||
args []types.Type
|
|
||||||
out types.Type
|
|
||||||
hasErr bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// A providerSetRef is a parsed reference to a collection of providers.
|
|
||||||
type providerSetRef struct {
|
|
||||||
importPath string
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseProviderSetRef(ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
|
|
||||||
// TODO(light): verify that provider set name is an identifier before returning
|
|
||||||
|
|
||||||
i := strings.LastIndexByte(ref, '.')
|
|
||||||
if i == -1 {
|
|
||||||
return providerSetRef{importPath: pkg, name: ref}, nil
|
|
||||||
}
|
|
||||||
imp, name := ref[:i], ref[i+1:]
|
|
||||||
if strings.HasPrefix(imp, `"`) {
|
|
||||||
path, err := strconv.Unquote(imp)
|
|
||||||
if err != nil {
|
|
||||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
|
|
||||||
}
|
|
||||||
return providerSetRef{importPath: path, name: name}, nil
|
|
||||||
}
|
|
||||||
_, obj := s.LookupParent(imp, pos)
|
|
||||||
if obj == nil {
|
|
||||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp)
|
|
||||||
}
|
|
||||||
pn, ok := obj.(*types.PkgName)
|
|
||||||
if !ok {
|
|
||||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp)
|
|
||||||
}
|
|
||||||
return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ref providerSetRef) String() string {
|
|
||||||
return strconv.Quote(ref.importPath) + "." + ref.name
|
|
||||||
}
|
|
||||||
|
|
||||||
type directive struct {
|
|
||||||
pos token.Pos
|
|
||||||
kind string
|
|
||||||
line string
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
|
|
||||||
const prefix = "goose:"
|
|
||||||
text := cg.Text()
|
|
||||||
for len(text) > 0 {
|
|
||||||
text = strings.TrimLeft(text, " \t\r\n")
|
|
||||||
if !strings.HasPrefix(text, prefix) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
line := text[len(prefix):]
|
|
||||||
if i := strings.IndexByte(line, '\n'); i != -1 {
|
|
||||||
line, text = line[:i], line[i+1:]
|
|
||||||
} else {
|
|
||||||
text = ""
|
|
||||||
}
|
|
||||||
if i := strings.IndexByte(line, ' '); i != -1 {
|
|
||||||
d = append(d, directive{
|
|
||||||
kind: line[:i],
|
|
||||||
line: strings.TrimSpace(line[i+1:]),
|
|
||||||
pos: cg.Pos(), // TODO(light): more precise position
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
d = append(d, directive{
|
|
||||||
kind: line,
|
|
||||||
pos: cg.Pos(), // TODO(light): more precise position
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
// isInjectFile reports whether a given file is an injection template.
|
|
||||||
func isInjectFile(f *ast.File) bool {
|
|
||||||
// TODO(light): better determination
|
|
||||||
for _, cg := range f.Comments {
|
|
||||||
text := cg.Text()
|
|
||||||
if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// zeroValue returns the shortest expression that evaluates to the zero
|
// zeroValue returns the shortest expression that evaluates to the zero
|
||||||
// value for the given type.
|
// value for the given type.
|
||||||
func zeroValue(t types.Type, qf types.Qualifier) string {
|
func zeroValue(t types.Type, qf types.Qualifier) string {
|
||||||
|
|||||||
293
internal/goose/parse.go
Normal file
293
internal/goose/parse.go
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
package goose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"go/ast"
|
||||||
|
"go/token"
|
||||||
|
"go/types"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/tools/go/loader"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A providerSet describes a set of providers. The zero value is an empty
|
||||||
|
// providerSet.
|
||||||
|
type providerSet struct {
|
||||||
|
providers []*providerInfo
|
||||||
|
imports []providerSetImport
|
||||||
|
}
|
||||||
|
|
||||||
|
type providerSetImport struct {
|
||||||
|
providerSetRef
|
||||||
|
pos token.Pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerInfo records the signature of a provider function.
|
||||||
|
type providerInfo struct {
|
||||||
|
importPath string
|
||||||
|
funcName string
|
||||||
|
pos token.Pos
|
||||||
|
args []types.Type
|
||||||
|
out types.Type
|
||||||
|
hasErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// findProviderSets processes a package and extracts the provider sets declared in it.
|
||||||
|
func findProviderSets(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
|
||||||
|
sets := make(map[string]*providerSet)
|
||||||
|
var directives []directive
|
||||||
|
for _, f := range files {
|
||||||
|
fileScope := typeInfo.Scopes[f]
|
||||||
|
for _, c := range f.Comments {
|
||||||
|
directives = extractDirectives(directives[:0], c)
|
||||||
|
for _, d := range directives {
|
||||||
|
switch d.kind {
|
||||||
|
case "provide", "use":
|
||||||
|
// handled later
|
||||||
|
case "import":
|
||||||
|
if fileScope == nil {
|
||||||
|
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
|
||||||
|
}
|
||||||
|
i := strings.IndexByte(d.line, ' ')
|
||||||
|
// TODO(light): allow multiple imports in one line
|
||||||
|
if i == -1 {
|
||||||
|
return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos))
|
||||||
|
}
|
||||||
|
name, spec := d.line[:i], d.line[i+1:]
|
||||||
|
ref, err := parseProviderSetRef(spec, fileScope, pkg.Path(), d.pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
|
||||||
|
}
|
||||||
|
if ref.importPath != pkg.Path() {
|
||||||
|
imported := false
|
||||||
|
for _, imp := range pkg.Imports() {
|
||||||
|
if ref.importPath == imp.Path() {
|
||||||
|
imported = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !imported {
|
||||||
|
return nil, fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mod := sets[name]; mod != nil {
|
||||||
|
found := false
|
||||||
|
for _, other := range mod.imports {
|
||||||
|
if ref == other.providerSetRef {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sets[name] = &providerSet{
|
||||||
|
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmap := ast.NewCommentMap(fset, f, f.Comments)
|
||||||
|
for _, decl := range f.Decls {
|
||||||
|
directives = directives[:0]
|
||||||
|
for _, cg := range cmap[decl] {
|
||||||
|
directives = extractDirectives(directives, cg)
|
||||||
|
}
|
||||||
|
fn, isFunction := decl.(*ast.FuncDecl)
|
||||||
|
var providerSetName string
|
||||||
|
for _, d := range directives {
|
||||||
|
if d.kind != "provide" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if providerSetName != "" {
|
||||||
|
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
if !isFunction {
|
||||||
|
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
|
||||||
|
}
|
||||||
|
providerSetName = fn.Name.Name
|
||||||
|
if d.line != "" {
|
||||||
|
// TODO(light): validate identifier
|
||||||
|
providerSetName = d.line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if providerSetName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fpos := fn.Pos()
|
||||||
|
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||||
|
r := sig.Results()
|
||||||
|
var hasErr bool
|
||||||
|
switch r.Len() {
|
||||||
|
case 1:
|
||||||
|
hasErr = false
|
||||||
|
case 2:
|
||||||
|
if t := r.At(1).Type(); !types.Identical(t, errorType) {
|
||||||
|
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
hasErr = true
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
out := r.At(0).Type()
|
||||||
|
p := sig.Params()
|
||||||
|
provider := &providerInfo{
|
||||||
|
importPath: pkg.Path(),
|
||||||
|
funcName: fn.Name.Name,
|
||||||
|
pos: fn.Pos(),
|
||||||
|
args: make([]types.Type, p.Len()),
|
||||||
|
out: out,
|
||||||
|
hasErr: hasErr,
|
||||||
|
}
|
||||||
|
for i := 0; i < p.Len(); i++ {
|
||||||
|
provider.args[i] = p.At(i).Type()
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
if types.Identical(provider.args[i], provider.args[j]) {
|
||||||
|
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mod := sets[providerSetName]; mod != nil {
|
||||||
|
for _, other := range mod.providers {
|
||||||
|
if types.Identical(other.out, provider.out) {
|
||||||
|
return nil, fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fset.Position(other.pos))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod.providers = append(mod.providers, provider)
|
||||||
|
} else {
|
||||||
|
sets[providerSetName] = &providerSet{
|
||||||
|
providers: []*providerInfo{provider},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerSetCache is a lazily evaluated index of provider sets.
|
||||||
|
type providerSetCache struct {
|
||||||
|
sets map[string]map[string]*providerSet
|
||||||
|
fset *token.FileSet
|
||||||
|
prog *loader.Program
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProviderSetCache(prog *loader.Program) *providerSetCache {
|
||||||
|
return &providerSetCache{
|
||||||
|
fset: prog.Fset,
|
||||||
|
prog: prog,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
||||||
|
if mods, cached := mc.sets[ref.importPath]; cached {
|
||||||
|
mod := mods[ref.name]
|
||||||
|
if mod == nil {
|
||||||
|
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
|
||||||
|
}
|
||||||
|
return mod, nil
|
||||||
|
}
|
||||||
|
if mc.sets == nil {
|
||||||
|
mc.sets = make(map[string]map[string]*providerSet)
|
||||||
|
}
|
||||||
|
pkg := mc.prog.Package(ref.importPath)
|
||||||
|
mods, err := findProviderSets(mc.fset, pkg.Pkg, &pkg.Info, pkg.Files)
|
||||||
|
if err != nil {
|
||||||
|
mc.sets[ref.importPath] = nil
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mc.sets[ref.importPath] = mods
|
||||||
|
mod := mods[ref.name]
|
||||||
|
if mod == nil {
|
||||||
|
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
|
||||||
|
}
|
||||||
|
return mod, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A providerSetRef is a parsed reference to a collection of providers.
|
||||||
|
type providerSetRef struct {
|
||||||
|
importPath string
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseProviderSetRef(ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
|
||||||
|
// TODO(light): verify that provider set name is an identifier before returning
|
||||||
|
|
||||||
|
i := strings.LastIndexByte(ref, '.')
|
||||||
|
if i == -1 {
|
||||||
|
return providerSetRef{importPath: pkg, name: ref}, nil
|
||||||
|
}
|
||||||
|
imp, name := ref[:i], ref[i+1:]
|
||||||
|
if strings.HasPrefix(imp, `"`) {
|
||||||
|
path, err := strconv.Unquote(imp)
|
||||||
|
if err != nil {
|
||||||
|
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
|
||||||
|
}
|
||||||
|
return providerSetRef{importPath: path, name: name}, nil
|
||||||
|
}
|
||||||
|
_, obj := s.LookupParent(imp, pos)
|
||||||
|
if obj == nil {
|
||||||
|
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp)
|
||||||
|
}
|
||||||
|
pn, ok := obj.(*types.PkgName)
|
||||||
|
if !ok {
|
||||||
|
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp)
|
||||||
|
}
|
||||||
|
return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ref providerSetRef) String() string {
|
||||||
|
return strconv.Quote(ref.importPath) + "." + ref.name
|
||||||
|
}
|
||||||
|
|
||||||
|
type directive struct {
|
||||||
|
pos token.Pos
|
||||||
|
kind string
|
||||||
|
line string
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
|
||||||
|
const prefix = "goose:"
|
||||||
|
text := cg.Text()
|
||||||
|
for len(text) > 0 {
|
||||||
|
text = strings.TrimLeft(text, " \t\r\n")
|
||||||
|
if !strings.HasPrefix(text, prefix) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
line := text[len(prefix):]
|
||||||
|
if i := strings.IndexByte(line, '\n'); i != -1 {
|
||||||
|
line, text = line[:i], line[i+1:]
|
||||||
|
} else {
|
||||||
|
text = ""
|
||||||
|
}
|
||||||
|
if i := strings.IndexByte(line, ' '); i != -1 {
|
||||||
|
d = append(d, directive{
|
||||||
|
kind: line[:i],
|
||||||
|
line: strings.TrimSpace(line[i+1:]),
|
||||||
|
pos: cg.Pos(), // TODO(light): more precise position
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
d = append(d, directive{
|
||||||
|
kind: line,
|
||||||
|
pos: cg.Pos(), // TODO(light): more precise position
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInjectFile reports whether a given file is an injection template.
|
||||||
|
func isInjectFile(f *ast.File) bool {
|
||||||
|
// TODO(light): better determination
|
||||||
|
for _, cg := range f.Comments {
|
||||||
|
text := cg.Text()
|
||||||
|
if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user