goose: add optional provider inputs

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-03-30 21:34:08 -07:00
parent c594f05699
commit 479a501c08
12 changed files with 370 additions and 138 deletions

View File

@@ -30,140 +30,41 @@ type providerInfo struct {
importPath string
funcName string
pos token.Pos
args []types.Type
args []providerInput
out types.Type
hasErr bool
}
type providerInput struct {
typ types.Type
optional bool
}
type findContext struct {
fset *token.FileSet
pkg *types.Package
typeInfo *types.Info
r *importResolver
}
// findProviderSets processes a package and extracts the provider sets declared in it.
func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) {
sets := make(map[string]*providerSet)
var directives []directive
for _, f := range files {
fileScope := typeInfo.Scopes[f]
for _, c := range f.Comments {
directives = extractDirectives(directives[:0], c)
for _, d := range directives {
switch d.kind {
case "provide", "use":
// handled later
case "import":
if fileScope == nil {
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
}
i := strings.IndexByte(d.line, ' ')
// TODO(light): allow multiple imports in one line
if i == -1 {
return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos))
}
name, spec := d.line[:i], d.line[i+1:]
ref, err := parseProviderSetRef(r, spec, fileScope, pkg.Path(), d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
}
if ref.importPath != pkg.Path() {
imported := false
for _, imp := range pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return nil, fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
}
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.providerSetRef {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
}
}
default:
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
}
}
fileScope := fctx.typeInfo.Scopes[f]
if fileScope == nil {
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fctx.fset.File(f.Pos()).Name())
}
cmap := ast.NewCommentMap(fset, f, f.Comments)
for _, decl := range f.Decls {
directives = directives[:0]
for _, cg := range cmap[decl] {
directives = extractDirectives(directives, cg)
}
fn, isFunction := decl.(*ast.FuncDecl)
var providerSetName string
for _, d := range directives {
if d.kind != "provide" {
continue
for _, dg := range parseFile(fctx.fset, f) {
if dg.decl != nil {
if err := processDeclDirectives(fctx, sets, fileScope, dg); err != nil {
return nil, err
}
if providerSetName != "" {
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
}
if !isFunction {
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
}
providerSetName = fn.Name.Name
if d.line != "" {
// TODO(light): validate identifier
providerSetName = d.line
}
}
if providerSetName == "" {
continue
}
fpos := fn.Pos()
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
r := sig.Results()
var hasErr bool
switch r.Len() {
case 1:
hasErr = false
case 2:
if t := r.At(1).Type(); !types.Identical(t, errorType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
}
hasErr = true
default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
}
out := r.At(0).Type()
p := sig.Params()
provider := &providerInfo{
importPath: pkg.Path(),
funcName: fn.Name.Name,
pos: fn.Pos(),
args: make([]types.Type, p.Len()),
out: out,
hasErr: hasErr,
}
for i := 0; i < p.Len(); i++ {
provider.args[i] = p.At(i).Type()
for j := 0; j < i; j++ {
if types.Identical(provider.args[i], provider.args[j]) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
}
}
}
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return nil, fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fset.Position(other.pos))
}
}
mod.providers = append(mod.providers, provider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider},
for _, d := range dg.dirs {
if err := processUnassociatedDirective(fctx, sets, fileScope, d); err != nil {
return nil, err
}
}
}
}
@@ -171,6 +72,147 @@ func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver
return sets, nil
}
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
switch d.kind {
case "provide", "optional":
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
case "use":
// Ignore, picked up by injector flow.
case "import":
i := strings.IndexByte(d.line, ' ')
// TODO(light): allow multiple imports in one line
if i == -1 {
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
}
name, spec := d.line[:i], d.line[i+1:]
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
if ref.importPath != fctx.pkg.Path() {
imported := false
for _, imp := range fctx.pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.providerSetRef {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
}
}
default:
return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind)
}
return nil
}
// processDeclDirectives processes the directives associated with a top-level declaration.
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error {
p, err := dg.single(fctx.fset, "provide")
if err != nil {
return err
}
if !p.isValid() {
for _, d := range dg.dirs {
if d.kind == "optional" {
return fmt.Errorf("%v: cannot use goose:%s directive on non-provider", fctx.fset.Position(d.pos), d.kind)
}
}
return nil
}
fn, ok := dg.decl.(*ast.FuncDecl)
if !ok {
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos))
}
sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
optionals := make([]bool, sig.Params().Len())
for _, d := range dg.dirs {
if d.kind == "optional" {
// Marking the given argument names as optional inputs.
for _, arg := range strings.Fields(d.line) {
pi := paramIndex(sig.Params(), arg)
if pi == -1 {
return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
}
optionals[pi] = true
}
}
}
fpos := fn.Pos()
r := sig.Results()
var hasErr bool
switch r.Len() {
case 1:
hasErr = false
case 2:
if t := r.At(1).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
}
hasErr = true
default:
return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
}
out := r.At(0).Type()
params := sig.Params()
provider := &providerInfo{
importPath: fctx.pkg.Path(),
funcName: fn.Name.Name,
pos: fn.Pos(),
args: make([]providerInput, params.Len()),
out: out,
hasErr: hasErr,
}
for i := 0; i < params.Len(); i++ {
provider.args[i] = providerInput{
typ: params.At(i).Type(),
optional: optionals[i],
}
for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
}
}
}
providerSetName := fn.Name.Name
if p.line != "" {
// TODO(light): validate identifier
providerSetName = p.line
}
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
}
}
mod.providers = append(mod.providers, provider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider},
}
}
return nil
}
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
sets map[string]map[string]*providerSet
@@ -199,7 +241,12 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
mc.sets = make(map[string]map[string]*providerSet)
}
pkg := mc.prog.Package(ref.importPath)
mods, err := findProviderSets(mc.fset, pkg.Pkg, mc.r, &pkg.Info, pkg.Files)
mods, err := findProviderSets(findContext{
fset: mc.fset,
pkg: pkg.Pkg,
typeInfo: &pkg.Info,
r: mc.r,
}, pkg.Files)
if err != nil {
mc.sets[ref.importPath] = nil
return nil, err
@@ -282,12 +329,68 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
return pkg.ImportPath, nil
}
// A directive is a parsed goose comment.
type directive struct {
pos token.Pos
kind string
line string
}
// A directiveGroup is a set of directives associated with a particular
// declaration.
type directiveGroup struct {
decl ast.Decl
dirs []directive
}
// parseFile extracts the directives from a file, grouped by declaration.
func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
cmap := ast.NewCommentMap(fset, f, f.Comments)
// Reserve first group for directives that don't associate with a
// declaration, like import.
groups := make([]directiveGroup, 1, len(f.Decls)+1)
// Walk declarations and add to groups.
for _, decl := range f.Decls {
grp := directiveGroup{decl: decl}
ast.Inspect(decl, func(node ast.Node) bool {
if g := cmap[node]; len(g) > 0 {
for _, cg := range g {
start := len(grp.dirs)
grp.dirs = extractDirectives(grp.dirs, cg)
// Move directives that don't associate into the unassociated group.
n := 0
for i := start; i < len(grp.dirs); i++ {
if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" {
grp.dirs[start+n] = grp.dirs[i]
n++
} else {
groups[0].dirs = append(groups[0].dirs, grp.dirs[i])
}
}
grp.dirs = grp.dirs[:start+n]
}
delete(cmap, node)
}
return true
})
if len(grp.dirs) > 0 {
groups = append(groups, grp)
}
}
// Place remaining directives into the unassociated group.
unassoc := &groups[0]
for _, g := range cmap {
for _, cg := range g {
unassoc.dirs = extractDirectives(unassoc.dirs, cg)
}
}
if len(unassoc.dirs) == 0 {
return groups[1:]
}
return groups
}
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
const prefix = "goose:"
text := cg.Text()
@@ -318,6 +421,37 @@ func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
return d
}
// single finds at most one directive that matches the given kind.
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
var found directive
ok := false
for _, d := range dg.dirs {
if d.kind != kind {
continue
}
if ok {
switch decl := dg.decl.(type) {
case *ast.FuncDecl:
return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, decl.Name.Name)
case *ast.GenDecl:
if decl.Tok == token.TYPE && len(decl.Specs) == 1 {
name := decl.Specs[0].(*ast.TypeSpec).Name.Name
return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, name)
}
return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
default:
return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
}
}
found, ok = d, true
}
return found, nil
}
func (d directive) isValid() bool {
return d.kind != ""
}
// isInjectFile reports whether a given file is an injection template.
func isInjectFile(f *ast.File) bool {
// TODO(light): better determination
@@ -329,3 +463,14 @@ func isInjectFile(f *ast.File) bool {
}
return false
}
// paramIndex returns the index of the parameter with the given name, or
// -1 if no such parameter exists.
func paramIndex(params *types.Tuple, name string) int {
for i := 0; i < params.Len(); i++ {
if params.At(i).Name() == name {
return i
}
}
return -1
}