goose: allow multiple arguments to use and import

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-02 10:57:48 -07:00
parent 5261a8a8bb
commit 73d4c0f0fc
11 changed files with 196 additions and 43 deletions

View File

@@ -71,11 +71,17 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
if d.kind != "use" { if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind) return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
} }
ref, err := parseProviderSetRef(r, d.line, fileScope, g.currPackage, d.pos) args := d.args()
if err != nil { if len(args) == 0 {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos))
}
for _, arg := range args {
ref, err := parseProviderSetRef(r, arg, fileScope, g.currPackage, d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
}
sets = append(sets, ref)
} }
sets = append(sets, ref)
} }
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil { if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil {

View File

@@ -9,6 +9,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"unicode"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
) )
@@ -80,42 +81,43 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
case "use": case "use":
// Ignore, picked up by injector flow. // Ignore, picked up by injector flow.
case "import": case "import":
i := strings.IndexByte(d.line, ' ') args := d.args()
// TODO(light): allow multiple imports in one line if len(args) < 2 {
if i == -1 {
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
} }
name, spec := d.line[:i], d.line[i+1:] name := args[0]
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) for _, spec := range args[1:] {
if err != nil { ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) if err != nil {
} return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
if ref.importPath != fctx.pkg.Path() { }
imported := false if ref.importPath != fctx.pkg.Path() {
for _, imp := range fctx.pkg.Imports() { imported := false
if ref.importPath == imp.Path() { for _, imp := range fctx.pkg.Imports() {
imported = true if ref.importPath == imp.Path() {
break imported = true
break
}
}
if !imported {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
} }
} }
if !imported { if mod := sets[name]; mod != nil {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) found := false
} for _, other := range mod.imports {
} if ref == other.providerSetRef {
if mod := sets[name]; mod != nil { found = true
found := false break
for _, other := range mod.imports { }
if ref == other.providerSetRef { }
found = true if !found {
break mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
} }
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
} }
} }
default: default:
@@ -148,7 +150,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
for _, d := range dg.dirs { for _, d := range dg.dirs {
if d.kind == "optional" { if d.kind == "optional" {
// Marking the given argument names as optional inputs. // Marking the given argument names as optional inputs.
for _, arg := range strings.Fields(d.line) { for _, arg := range d.args() {
pi := paramIndex(sig.Params(), arg) pi := paramIndex(sig.Params(), arg)
if pi == -1 { if pi == -1 {
return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name) return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
@@ -194,9 +196,11 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
} }
} }
providerSetName := fn.Name.Name providerSetName := fn.Name.Name
if p.line != "" { if args := p.args(); len(args) == 1 {
// TODO(light): validate identifier // TODO(light): validate identifier
providerSetName = p.line providerSetName = args[0]
} else if len(args) > 1 {
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos))
} }
if mod := sets[providerSetName]; mod != nil { if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers { for _, other := range mod.providers {
@@ -400,11 +404,9 @@ func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
break break
} }
line := text[len(prefix):] line := text[len(prefix):]
if i := strings.IndexByte(line, '\n'); i != -1 { // Text() is always newline terminated.
line, text = line[:i], line[i+1:] i := strings.IndexByte(line, '\n')
} else { line, text = line[:i], line[i+1:]
text = ""
}
if i := strings.IndexByte(line, ' '); i != -1 { if i := strings.IndexByte(line, ' '); i != -1 {
d = append(d, directive{ d = append(d, directive{
kind: line[:i], kind: line[:i],
@@ -452,6 +454,54 @@ func (d directive) isValid() bool {
return d.kind != "" return d.kind != ""
} }
// args splits the directive line into tokens.
func (d directive) args() []string {
var args []string
start := -1
state := 0 // 0 = boundary, 1 = in token, 2 = in quote, 3 = quote backslash
for i, r := range d.line {
switch state {
case 0:
// Argument boundary
switch {
case r == '"':
start = i
state = 2
case !unicode.IsSpace(r):
start = i
state = 1
}
case 1:
// In token
switch {
case unicode.IsSpace(r):
args = append(args, d.line[start:i])
start = -1
state = 0
case r == '"':
state = 2
}
case 2:
// In quotes
switch {
case r == '"':
state = 1
case r == '\\':
state = 3
}
case 3:
// Quote backslash. Consumes one character and jumps back into "in quote" state.
state = 2
default:
panic("unreachable")
}
}
if start != -1 {
args = append(args, d.line[start:])
}
return args
}
// isInjectFile reports whether a given file is an injection template. // isInjectFile reports whether a given file is an injection template.
func isInjectFile(f *ast.File) bool { func isInjectFile(f *ast.File) bool {
// TODO(light): better determination // TODO(light): better determination

View File

@@ -0,0 +1,37 @@
package goose
import (
"testing"
)
func TestDirectiveArgs(t *testing.T) {
tests := []struct {
line string
args []string
}{
{"", []string{}},
{" \t ", []string{}},
{"foo", []string{"foo"}},
{"foo bar", []string{"foo", "bar"}},
{" foo \t bar ", []string{"foo", "bar"}},
{"foo \"bar \t baz\" fido", []string{"foo", "\"bar \t baz\"", "fido"}},
{"foo \"bar \t baz\".quux fido", []string{"foo", "\"bar \t baz\".quux", "fido"}},
}
eq := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
for _, test := range tests {
got := (directive{line: test.line}).args()
if !eq(got, test.args) {
t.Errorf("directive{line: %q}.args() = %q; want %q", test.line, got, test.args)
}
}
}

View File

@@ -0,0 +1,22 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide Foo
func provideFoo() Foo {
return 41
}
//goose:provide FooBar
func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1
}
//goose:import Set Foo FooBar

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
42

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,20 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide Foo
func provideFoo() Foo {
return 41
}
//goose:provide FooBar
func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Foo FooBar
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
42

1
internal/goose/testdata/MultiUse/pkg vendored Normal file
View File

@@ -0,0 +1 @@
foo