goose: allow multiple arguments to use and import

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-02 10:57:48 -07:00
parent 5261a8a8bb
commit 73d4c0f0fc
11 changed files with 196 additions and 43 deletions

View File

@@ -71,11 +71,17 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
}
ref, err := parseProviderSetRef(r, d.line, fileScope, g.currPackage, d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
args := d.args()
if len(args) == 0 {
return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos))
}
for _, arg := range args {
ref, err := parseProviderSetRef(r, arg, fileScope, g.currPackage, d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
}
sets = append(sets, ref)
}
sets = append(sets, ref)
}
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil {

View File

@@ -9,6 +9,7 @@ import (
"path/filepath"
"strconv"
"strings"
"unicode"
"golang.org/x/tools/go/loader"
)
@@ -80,42 +81,43 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
case "use":
// Ignore, picked up by injector flow.
case "import":
i := strings.IndexByte(d.line, ' ')
// TODO(light): allow multiple imports in one line
if i == -1 {
args := d.args()
if len(args) < 2 {
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
}
name, spec := d.line[:i], d.line[i+1:]
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
if ref.importPath != fctx.pkg.Path() {
imported := false
for _, imp := range fctx.pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
name := args[0]
for _, spec := range args[1:] {
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
if ref.importPath != fctx.pkg.Path() {
imported := false
for _, imp := range fctx.pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
}
if !imported {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.providerSetRef {
found = true
break
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.providerSetRef {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
}
}
default:
@@ -148,7 +150,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
for _, d := range dg.dirs {
if d.kind == "optional" {
// Marking the given argument names as optional inputs.
for _, arg := range strings.Fields(d.line) {
for _, arg := range d.args() {
pi := paramIndex(sig.Params(), arg)
if pi == -1 {
return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
@@ -194,9 +196,11 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
}
}
providerSetName := fn.Name.Name
if p.line != "" {
if args := p.args(); len(args) == 1 {
// TODO(light): validate identifier
providerSetName = p.line
providerSetName = args[0]
} else if len(args) > 1 {
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos))
}
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
@@ -400,11 +404,9 @@ func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
break
}
line := text[len(prefix):]
if i := strings.IndexByte(line, '\n'); i != -1 {
line, text = line[:i], line[i+1:]
} else {
text = ""
}
// Text() is always newline terminated.
i := strings.IndexByte(line, '\n')
line, text = line[:i], line[i+1:]
if i := strings.IndexByte(line, ' '); i != -1 {
d = append(d, directive{
kind: line[:i],
@@ -452,6 +454,54 @@ func (d directive) isValid() bool {
return d.kind != ""
}
// args splits the directive line into tokens.
func (d directive) args() []string {
var args []string
start := -1
state := 0 // 0 = boundary, 1 = in token, 2 = in quote, 3 = quote backslash
for i, r := range d.line {
switch state {
case 0:
// Argument boundary
switch {
case r == '"':
start = i
state = 2
case !unicode.IsSpace(r):
start = i
state = 1
}
case 1:
// In token
switch {
case unicode.IsSpace(r):
args = append(args, d.line[start:i])
start = -1
state = 0
case r == '"':
state = 2
}
case 2:
// In quotes
switch {
case r == '"':
state = 1
case r == '\\':
state = 3
}
case 3:
// Quote backslash. Consumes one character and jumps back into "in quote" state.
state = 2
default:
panic("unreachable")
}
}
if start != -1 {
args = append(args, d.line[start:])
}
return args
}
// isInjectFile reports whether a given file is an injection template.
func isInjectFile(f *ast.File) bool {
// TODO(light): better determination

View File

@@ -0,0 +1,37 @@
package goose
import (
"testing"
)
func TestDirectiveArgs(t *testing.T) {
tests := []struct {
line string
args []string
}{
{"", []string{}},
{" \t ", []string{}},
{"foo", []string{"foo"}},
{"foo bar", []string{"foo", "bar"}},
{" foo \t bar ", []string{"foo", "bar"}},
{"foo \"bar \t baz\" fido", []string{"foo", "\"bar \t baz\"", "fido"}},
{"foo \"bar \t baz\".quux fido", []string{"foo", "\"bar \t baz\".quux", "fido"}},
}
eq := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
for _, test := range tests {
got := (directive{line: test.line}).args()
if !eq(got, test.args) {
t.Errorf("directive{line: %q}.args() = %q; want %q", test.line, got, test.args)
}
}
}

View File

@@ -0,0 +1,22 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide Foo
func provideFoo() Foo {
return 41
}
//goose:provide FooBar
func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1
}
//goose:import Set Foo FooBar

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
42

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,20 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide Foo
func provideFoo() Foo {
return 41
}
//goose:provide FooBar
func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Foo FooBar
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
42

1
internal/goose/testdata/MultiUse/pkg vendored Normal file
View File

@@ -0,0 +1 @@
foo