goose: add optional provider inputs
Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
18
README.md
18
README.md
@@ -206,6 +206,24 @@ through the dependency graph, you would create a wrapping type:
|
|||||||
type MySQLConnectionString string
|
type MySQLConnectionString string
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
### Optional Inputs
|
||||||
|
|
||||||
|
A provider input can be marked optional using `goose:optional`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
//goose:provide Bar
|
||||||
|
//goose:optional foo
|
||||||
|
|
||||||
|
func provideBar(foo Foo) Bar {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
If used as part of an injector that does not bring in the `Foo` dependency, then
|
||||||
|
the injector will pass the provider the zero value as the `foo` argument.
|
||||||
|
|
||||||
## Future Work
|
## Future Work
|
||||||
|
|
||||||
- Support for map bindings.
|
- Support for map bindings.
|
||||||
|
|||||||
@@ -14,14 +14,16 @@ type call struct {
|
|||||||
importPath string
|
importPath string
|
||||||
funcName string
|
funcName string
|
||||||
|
|
||||||
// args is a list of arguments to call the provider with. Each element is either:
|
// args is a list of arguments to call the provider with. Each element is:
|
||||||
// a) one of the givens (args[i] < len(given)) or
|
// a) one of the givens (args[i] < len(given)),
|
||||||
// b) the result of a previous provider call (args[i] >= len(given)).
|
// b) the result of a previous provider call (args[i] >= len(given)), or
|
||||||
|
// c) the zero value for the type (args[i] == -1).
|
||||||
args []int
|
args []int
|
||||||
|
|
||||||
|
// ins is the list of types this call receives as arguments.
|
||||||
|
ins []types.Type
|
||||||
// out is the type produced by this provider call.
|
// out is the type produced by this provider call.
|
||||||
out types.Type
|
out types.Type
|
||||||
|
|
||||||
// hasErr is true if the provider call returns an error.
|
// hasErr is true if the provider call returns an error.
|
||||||
hasErr bool
|
hasErr bool
|
||||||
}
|
}
|
||||||
@@ -56,14 +58,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
|
|||||||
// using a depth-first search. The graph may contain cycles, which
|
// using a depth-first search. The graph may contain cycles, which
|
||||||
// should trigger an error.
|
// should trigger an error.
|
||||||
var calls []call
|
var calls []call
|
||||||
var visit func(trail []types.Type) error
|
var visit func(trail []providerInput) error
|
||||||
visit = func(trail []types.Type) error {
|
visit = func(trail []providerInput) error {
|
||||||
typ := trail[len(trail)-1]
|
typ := trail[len(trail)-1].typ
|
||||||
if index.At(typ) != nil {
|
if index.At(typ) != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, t := range trail[:len(trail)-1] {
|
for _, in := range trail[:len(trail)-1] {
|
||||||
if types.Identical(typ, t) {
|
if types.Identical(typ, in.typ) {
|
||||||
// TODO(light): describe cycle
|
// TODO(light): describe cycle
|
||||||
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
||||||
}
|
}
|
||||||
@@ -71,11 +73,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
|
|||||||
|
|
||||||
p, _ := providers.At(typ).(*providerInfo)
|
p, _ := providers.At(typ).(*providerInfo)
|
||||||
if p == nil {
|
if p == nil {
|
||||||
|
if trail[len(trail)-1].optional {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if len(trail) == 1 {
|
if len(trail) == 1 {
|
||||||
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
||||||
}
|
}
|
||||||
// TODO(light): give name of provider
|
// TODO(light): give name of provider
|
||||||
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil))
|
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
|
||||||
}
|
}
|
||||||
for _, a := range p.args {
|
for _, a := range p.args {
|
||||||
// TODO(light): this will discard grown trail arrays.
|
// TODO(light): this will discard grown trail arrays.
|
||||||
@@ -84,20 +89,27 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
args := make([]int, len(p.args))
|
args := make([]int, len(p.args))
|
||||||
|
ins := make([]types.Type, len(p.args))
|
||||||
for i := range p.args {
|
for i := range p.args {
|
||||||
args[i] = index.At(p.args[i]).(int)
|
ins[i] = p.args[i].typ
|
||||||
|
if x := index.At(p.args[i].typ); x != nil {
|
||||||
|
args[i] = x.(int)
|
||||||
|
} else {
|
||||||
|
args[i] = -1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
index.Set(typ, len(given)+len(calls))
|
index.Set(typ, len(given)+len(calls))
|
||||||
calls = append(calls, call{
|
calls = append(calls, call{
|
||||||
importPath: p.importPath,
|
importPath: p.importPath,
|
||||||
funcName: p.funcName,
|
funcName: p.funcName,
|
||||||
args: args,
|
args: args,
|
||||||
|
ins: ins,
|
||||||
out: typ,
|
out: typ,
|
||||||
hasErr: p.hasErr,
|
hasErr: p.hasErr,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err := visit([]types.Type{out}); err != nil {
|
if err := visit([]providerInput{{typ: out}}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return calls, nil
|
return calls, nil
|
||||||
|
|||||||
@@ -174,6 +174,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
}
|
}
|
||||||
for _, c := range calls {
|
for _, c := range calls {
|
||||||
g.qualifyImport(c.importPath)
|
g.qualifyImport(c.importPath)
|
||||||
|
for i := range c.args {
|
||||||
|
if c.args[i] == -1 {
|
||||||
|
zeroValue(c.ins[i], g.qualifyPkg)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
outTypeString := types.TypeString(outType, g.qualifyPkg)
|
outTypeString := types.TypeString(outType, g.qualifyPkg)
|
||||||
zv := zeroValue(outType, g.qualifyPkg)
|
zv := zeroValue(outType, g.qualifyPkg)
|
||||||
@@ -236,7 +241,9 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
|
|||||||
if j > 0 {
|
if j > 0 {
|
||||||
g.p(", ")
|
g.p(", ")
|
||||||
}
|
}
|
||||||
if a < params.Len() {
|
if a == -1 {
|
||||||
|
g.p("%s", zeroValue(c.ins[j], g.qualifyPkg))
|
||||||
|
} else if a < params.Len() {
|
||||||
g.p("%s", paramNames[a])
|
g.p("%s", paramNames[a])
|
||||||
} else {
|
} else {
|
||||||
g.p("%s", localNames[a-params.Len()])
|
g.p("%s", localNames[a-params.Len()])
|
||||||
|
|||||||
@@ -30,140 +30,41 @@ type providerInfo struct {
|
|||||||
importPath string
|
importPath string
|
||||||
funcName string
|
funcName string
|
||||||
pos token.Pos
|
pos token.Pos
|
||||||
args []types.Type
|
args []providerInput
|
||||||
out types.Type
|
out types.Type
|
||||||
hasErr bool
|
hasErr bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type providerInput struct {
|
||||||
|
typ types.Type
|
||||||
|
optional bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type findContext struct {
|
||||||
|
fset *token.FileSet
|
||||||
|
pkg *types.Package
|
||||||
|
typeInfo *types.Info
|
||||||
|
r *importResolver
|
||||||
|
}
|
||||||
|
|
||||||
// findProviderSets processes a package and extracts the provider sets declared in it.
|
// findProviderSets processes a package and extracts the provider sets declared in it.
|
||||||
func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
|
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) {
|
||||||
sets := make(map[string]*providerSet)
|
sets := make(map[string]*providerSet)
|
||||||
var directives []directive
|
|
||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
fileScope := typeInfo.Scopes[f]
|
fileScope := fctx.typeInfo.Scopes[f]
|
||||||
for _, c := range f.Comments {
|
if fileScope == nil {
|
||||||
directives = extractDirectives(directives[:0], c)
|
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fctx.fset.File(f.Pos()).Name())
|
||||||
for _, d := range directives {
|
|
||||||
switch d.kind {
|
|
||||||
case "provide", "use":
|
|
||||||
// handled later
|
|
||||||
case "import":
|
|
||||||
if fileScope == nil {
|
|
||||||
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
|
|
||||||
}
|
|
||||||
i := strings.IndexByte(d.line, ' ')
|
|
||||||
// TODO(light): allow multiple imports in one line
|
|
||||||
if i == -1 {
|
|
||||||
return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos))
|
|
||||||
}
|
|
||||||
name, spec := d.line[:i], d.line[i+1:]
|
|
||||||
ref, err := parseProviderSetRef(r, spec, fileScope, pkg.Path(), d.pos)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
|
|
||||||
}
|
|
||||||
if ref.importPath != pkg.Path() {
|
|
||||||
imported := false
|
|
||||||
for _, imp := range pkg.Imports() {
|
|
||||||
if ref.importPath == imp.Path() {
|
|
||||||
imported = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !imported {
|
|
||||||
return nil, fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if mod := sets[name]; mod != nil {
|
|
||||||
found := false
|
|
||||||
for _, other := range mod.imports {
|
|
||||||
if ref == other.providerSetRef {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sets[name] = &providerSet{
|
|
||||||
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
cmap := ast.NewCommentMap(fset, f, f.Comments)
|
for _, dg := range parseFile(fctx.fset, f) {
|
||||||
for _, decl := range f.Decls {
|
if dg.decl != nil {
|
||||||
directives = directives[:0]
|
if err := processDeclDirectives(fctx, sets, fileScope, dg); err != nil {
|
||||||
for _, cg := range cmap[decl] {
|
return nil, err
|
||||||
directives = extractDirectives(directives, cg)
|
|
||||||
}
|
|
||||||
fn, isFunction := decl.(*ast.FuncDecl)
|
|
||||||
var providerSetName string
|
|
||||||
for _, d := range directives {
|
|
||||||
if d.kind != "provide" {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if providerSetName != "" {
|
|
||||||
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
|
|
||||||
}
|
|
||||||
if !isFunction {
|
|
||||||
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
|
|
||||||
}
|
|
||||||
providerSetName = fn.Name.Name
|
|
||||||
if d.line != "" {
|
|
||||||
// TODO(light): validate identifier
|
|
||||||
providerSetName = d.line
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if providerSetName == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fpos := fn.Pos()
|
|
||||||
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
|
||||||
r := sig.Results()
|
|
||||||
var hasErr bool
|
|
||||||
switch r.Len() {
|
|
||||||
case 1:
|
|
||||||
hasErr = false
|
|
||||||
case 2:
|
|
||||||
if t := r.At(1).Type(); !types.Identical(t, errorType) {
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
|
|
||||||
}
|
|
||||||
hasErr = true
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
|
|
||||||
}
|
|
||||||
out := r.At(0).Type()
|
|
||||||
p := sig.Params()
|
|
||||||
provider := &providerInfo{
|
|
||||||
importPath: pkg.Path(),
|
|
||||||
funcName: fn.Name.Name,
|
|
||||||
pos: fn.Pos(),
|
|
||||||
args: make([]types.Type, p.Len()),
|
|
||||||
out: out,
|
|
||||||
hasErr: hasErr,
|
|
||||||
}
|
|
||||||
for i := 0; i < p.Len(); i++ {
|
|
||||||
provider.args[i] = p.At(i).Type()
|
|
||||||
for j := 0; j < i; j++ {
|
|
||||||
if types.Identical(provider.args[i], provider.args[j]) {
|
|
||||||
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if mod := sets[providerSetName]; mod != nil {
|
|
||||||
for _, other := range mod.providers {
|
|
||||||
if types.Identical(other.out, provider.out) {
|
|
||||||
return nil, fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fset.Position(other.pos))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mod.providers = append(mod.providers, provider)
|
|
||||||
} else {
|
} else {
|
||||||
sets[providerSetName] = &providerSet{
|
for _, d := range dg.dirs {
|
||||||
providers: []*providerInfo{provider},
|
if err := processUnassociatedDirective(fctx, sets, fileScope, d); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,6 +72,147 @@ func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver
|
|||||||
return sets, nil
|
return sets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
|
||||||
|
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
|
||||||
|
switch d.kind {
|
||||||
|
case "provide", "optional":
|
||||||
|
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
|
||||||
|
case "use":
|
||||||
|
// Ignore, picked up by injector flow.
|
||||||
|
case "import":
|
||||||
|
i := strings.IndexByte(d.line, ' ')
|
||||||
|
// TODO(light): allow multiple imports in one line
|
||||||
|
if i == -1 {
|
||||||
|
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
|
||||||
|
}
|
||||||
|
name, spec := d.line[:i], d.line[i+1:]
|
||||||
|
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
|
||||||
|
}
|
||||||
|
if ref.importPath != fctx.pkg.Path() {
|
||||||
|
imported := false
|
||||||
|
for _, imp := range fctx.pkg.Imports() {
|
||||||
|
if ref.importPath == imp.Path() {
|
||||||
|
imported = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !imported {
|
||||||
|
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mod := sets[name]; mod != nil {
|
||||||
|
found := false
|
||||||
|
for _, other := range mod.imports {
|
||||||
|
if ref == other.providerSetRef {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sets[name] = &providerSet{
|
||||||
|
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processDeclDirectives processes the directives associated with a top-level declaration.
|
||||||
|
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error {
|
||||||
|
p, err := dg.single(fctx.fset, "provide")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !p.isValid() {
|
||||||
|
for _, d := range dg.dirs {
|
||||||
|
if d.kind == "optional" {
|
||||||
|
return fmt.Errorf("%v: cannot use goose:%s directive on non-provider", fctx.fset.Position(d.pos), d.kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fn, ok := dg.decl.(*ast.FuncDecl)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos))
|
||||||
|
}
|
||||||
|
sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||||
|
|
||||||
|
optionals := make([]bool, sig.Params().Len())
|
||||||
|
for _, d := range dg.dirs {
|
||||||
|
if d.kind == "optional" {
|
||||||
|
// Marking the given argument names as optional inputs.
|
||||||
|
for _, arg := range strings.Fields(d.line) {
|
||||||
|
pi := paramIndex(sig.Params(), arg)
|
||||||
|
if pi == -1 {
|
||||||
|
return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
|
||||||
|
}
|
||||||
|
optionals[pi] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fpos := fn.Pos()
|
||||||
|
r := sig.Results()
|
||||||
|
var hasErr bool
|
||||||
|
switch r.Len() {
|
||||||
|
case 1:
|
||||||
|
hasErr = false
|
||||||
|
case 2:
|
||||||
|
if t := r.At(1).Type(); !types.Identical(t, errorType) {
|
||||||
|
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
hasErr = true
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
out := r.At(0).Type()
|
||||||
|
params := sig.Params()
|
||||||
|
provider := &providerInfo{
|
||||||
|
importPath: fctx.pkg.Path(),
|
||||||
|
funcName: fn.Name.Name,
|
||||||
|
pos: fn.Pos(),
|
||||||
|
args: make([]providerInput, params.Len()),
|
||||||
|
out: out,
|
||||||
|
hasErr: hasErr,
|
||||||
|
}
|
||||||
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
provider.args[i] = providerInput{
|
||||||
|
typ: params.At(i).Type(),
|
||||||
|
optional: optionals[i],
|
||||||
|
}
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
|
||||||
|
return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
providerSetName := fn.Name.Name
|
||||||
|
if p.line != "" {
|
||||||
|
// TODO(light): validate identifier
|
||||||
|
providerSetName = p.line
|
||||||
|
}
|
||||||
|
if mod := sets[providerSetName]; mod != nil {
|
||||||
|
for _, other := range mod.providers {
|
||||||
|
if types.Identical(other.out, provider.out) {
|
||||||
|
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod.providers = append(mod.providers, provider)
|
||||||
|
} else {
|
||||||
|
sets[providerSetName] = &providerSet{
|
||||||
|
providers: []*providerInfo{provider},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// providerSetCache is a lazily evaluated index of provider sets.
|
// providerSetCache is a lazily evaluated index of provider sets.
|
||||||
type providerSetCache struct {
|
type providerSetCache struct {
|
||||||
sets map[string]map[string]*providerSet
|
sets map[string]map[string]*providerSet
|
||||||
@@ -199,7 +241,12 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
|||||||
mc.sets = make(map[string]map[string]*providerSet)
|
mc.sets = make(map[string]map[string]*providerSet)
|
||||||
}
|
}
|
||||||
pkg := mc.prog.Package(ref.importPath)
|
pkg := mc.prog.Package(ref.importPath)
|
||||||
mods, err := findProviderSets(mc.fset, pkg.Pkg, mc.r, &pkg.Info, pkg.Files)
|
mods, err := findProviderSets(findContext{
|
||||||
|
fset: mc.fset,
|
||||||
|
pkg: pkg.Pkg,
|
||||||
|
typeInfo: &pkg.Info,
|
||||||
|
r: mc.r,
|
||||||
|
}, pkg.Files)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mc.sets[ref.importPath] = nil
|
mc.sets[ref.importPath] = nil
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -282,12 +329,68 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
|
|||||||
return pkg.ImportPath, nil
|
return pkg.ImportPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A directive is a parsed goose comment.
|
||||||
type directive struct {
|
type directive struct {
|
||||||
pos token.Pos
|
pos token.Pos
|
||||||
kind string
|
kind string
|
||||||
line string
|
line string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A directiveGroup is a set of directives associated with a particular
|
||||||
|
// declaration.
|
||||||
|
type directiveGroup struct {
|
||||||
|
decl ast.Decl
|
||||||
|
dirs []directive
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFile extracts the directives from a file, grouped by declaration.
|
||||||
|
func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
|
||||||
|
cmap := ast.NewCommentMap(fset, f, f.Comments)
|
||||||
|
// Reserve first group for directives that don't associate with a
|
||||||
|
// declaration, like import.
|
||||||
|
groups := make([]directiveGroup, 1, len(f.Decls)+1)
|
||||||
|
// Walk declarations and add to groups.
|
||||||
|
for _, decl := range f.Decls {
|
||||||
|
grp := directiveGroup{decl: decl}
|
||||||
|
ast.Inspect(decl, func(node ast.Node) bool {
|
||||||
|
if g := cmap[node]; len(g) > 0 {
|
||||||
|
for _, cg := range g {
|
||||||
|
start := len(grp.dirs)
|
||||||
|
grp.dirs = extractDirectives(grp.dirs, cg)
|
||||||
|
|
||||||
|
// Move directives that don't associate into the unassociated group.
|
||||||
|
n := 0
|
||||||
|
for i := start; i < len(grp.dirs); i++ {
|
||||||
|
if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" {
|
||||||
|
grp.dirs[start+n] = grp.dirs[i]
|
||||||
|
n++
|
||||||
|
} else {
|
||||||
|
groups[0].dirs = append(groups[0].dirs, grp.dirs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
grp.dirs = grp.dirs[:start+n]
|
||||||
|
}
|
||||||
|
delete(cmap, node)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if len(grp.dirs) > 0 {
|
||||||
|
groups = append(groups, grp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Place remaining directives into the unassociated group.
|
||||||
|
unassoc := &groups[0]
|
||||||
|
for _, g := range cmap {
|
||||||
|
for _, cg := range g {
|
||||||
|
unassoc.dirs = extractDirectives(unassoc.dirs, cg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(unassoc.dirs) == 0 {
|
||||||
|
return groups[1:]
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
|
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
|
||||||
const prefix = "goose:"
|
const prefix = "goose:"
|
||||||
text := cg.Text()
|
text := cg.Text()
|
||||||
@@ -318,6 +421,37 @@ func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// single finds at most one directive that matches the given kind.
|
||||||
|
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
|
||||||
|
var found directive
|
||||||
|
ok := false
|
||||||
|
for _, d := range dg.dirs {
|
||||||
|
if d.kind != kind {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
switch decl := dg.decl.(type) {
|
||||||
|
case *ast.FuncDecl:
|
||||||
|
return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, decl.Name.Name)
|
||||||
|
case *ast.GenDecl:
|
||||||
|
if decl.Tok == token.TYPE && len(decl.Specs) == 1 {
|
||||||
|
name := decl.Specs[0].(*ast.TypeSpec).Name.Name
|
||||||
|
return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, name)
|
||||||
|
}
|
||||||
|
return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
|
||||||
|
default:
|
||||||
|
return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
found, ok = d, true
|
||||||
|
}
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d directive) isValid() bool {
|
||||||
|
return d.kind != ""
|
||||||
|
}
|
||||||
|
|
||||||
// isInjectFile reports whether a given file is an injection template.
|
// isInjectFile reports whether a given file is an injection template.
|
||||||
func isInjectFile(f *ast.File) bool {
|
func isInjectFile(f *ast.File) bool {
|
||||||
// TODO(light): better determination
|
// TODO(light): better determination
|
||||||
@@ -329,3 +463,14 @@ func isInjectFile(f *ast.File) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// paramIndex returns the index of the parameter with the given name, or
|
||||||
|
// -1 if no such parameter exists.
|
||||||
|
func paramIndex(params *types.Tuple, name string) int {
|
||||||
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
if params.At(i).Name() == name {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|||||||
16
internal/goose/testdata/OptionalMissing/foo/foo.go
vendored
Normal file
16
internal/goose/testdata/OptionalMissing/foo/foo.go
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectBar())
|
||||||
|
}
|
||||||
|
|
||||||
|
type foo int
|
||||||
|
type bar int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
//goose:optional f
|
||||||
|
func provideBar(f foo) bar {
|
||||||
|
return bar(f)
|
||||||
|
}
|
||||||
7
internal/goose/testdata/OptionalMissing/foo/foo_goose.go
vendored
Normal file
7
internal/goose/testdata/OptionalMissing/foo/foo_goose.go
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use provideBar
|
||||||
|
|
||||||
|
func injectBar() bar
|
||||||
1
internal/goose/testdata/OptionalMissing/out.txt
vendored
Normal file
1
internal/goose/testdata/OptionalMissing/out.txt
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
internal/goose/testdata/OptionalMissing/pkg
vendored
Normal file
1
internal/goose/testdata/OptionalMissing/pkg
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
foo
|
||||||
16
internal/goose/testdata/OptionalPresent/foo/foo.go
vendored
Normal file
16
internal/goose/testdata/OptionalPresent/foo/foo.go
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectBar(42))
|
||||||
|
}
|
||||||
|
|
||||||
|
type foo int
|
||||||
|
type bar int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
//goose:optional f
|
||||||
|
func provideBar(f foo) bar {
|
||||||
|
return bar(f)
|
||||||
|
}
|
||||||
7
internal/goose/testdata/OptionalPresent/foo/foo_goose.go
vendored
Normal file
7
internal/goose/testdata/OptionalPresent/foo/foo_goose.go
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use provideBar
|
||||||
|
|
||||||
|
func injectBar(foo) bar
|
||||||
1
internal/goose/testdata/OptionalPresent/out.txt
vendored
Normal file
1
internal/goose/testdata/OptionalPresent/out.txt
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
42
|
||||||
1
internal/goose/testdata/OptionalPresent/pkg
vendored
Normal file
1
internal/goose/testdata/OptionalPresent/pkg
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
foo
|
||||||
Reference in New Issue
Block a user