diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 253af42..7ae2daf 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -71,11 +71,17 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { if d.kind != "use" { return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind) } - ref, err := parseProviderSetRef(r, d.line, fileScope, g.currPackage, d.pos) - if err != nil { - return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) + args := d.args() + if len(args) == 0 { + return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos)) + } + for _, arg := range args { + ref, err := parseProviderSetRef(r, arg, fileScope, g.currPackage, d.pos) + if err != nil { + return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) + } + sets = append(sets, ref) } - sets = append(sets, ref) } sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil { diff --git a/internal/goose/parse.go b/internal/goose/parse.go index a38e3e0..6e38a5b 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strconv" "strings" + "unicode" "golang.org/x/tools/go/loader" ) @@ -80,42 +81,43 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet case "use": // Ignore, picked up by injector flow. case "import": - i := strings.IndexByte(d.line, ' ') - // TODO(light): allow multiple imports in one line - if i == -1 { + args := d.args() + if len(args) < 2 { return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) } - name, spec := d.line[:i], d.line[i+1:] - ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) - if err != nil { - return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) - } - if ref.importPath != fctx.pkg.Path() { - imported := false - for _, imp := range fctx.pkg.Imports() { - if ref.importPath == imp.Path() { - imported = true - break + name := args[0] + for _, spec := range args[1:] { + ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) + if err != nil { + return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + } + if ref.importPath != fctx.pkg.Path() { + imported := false + for _, imp := range fctx.pkg.Imports() { + if ref.importPath == imp.Path() { + imported = true + break + } + } + if !imported { + return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) } } - if !imported { - return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) - } - } - if mod := sets[name]; mod != nil { - found := false - for _, other := range mod.imports { - if ref == other.providerSetRef { - found = true - break + if mod := sets[name]; mod != nil { + found := false + for _, other := range mod.imports { + if ref == other.providerSetRef { + found = true + break + } + } + if !found { + mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos}) + } + } else { + sets[name] = &providerSet{ + imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}}, } - } - if !found { - mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos}) - } - } else { - sets[name] = &providerSet{ - imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}}, } } default: @@ -148,7 +150,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope for _, d := range dg.dirs { if d.kind == "optional" { // Marking the given argument names as optional inputs. - for _, arg := range strings.Fields(d.line) { + for _, arg := range d.args() { pi := paramIndex(sig.Params(), arg) if pi == -1 { return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name) @@ -194,9 +196,11 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope } } providerSetName := fn.Name.Name - if p.line != "" { + if args := p.args(); len(args) == 1 { // TODO(light): validate identifier - providerSetName = p.line + providerSetName = args[0] + } else if len(args) > 1 { + return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos)) } if mod := sets[providerSetName]; mod != nil { for _, other := range mod.providers { @@ -400,11 +404,9 @@ func extractDirectives(d []directive, cg *ast.CommentGroup) []directive { break } line := text[len(prefix):] - if i := strings.IndexByte(line, '\n'); i != -1 { - line, text = line[:i], line[i+1:] - } else { - text = "" - } + // Text() is always newline terminated. + i := strings.IndexByte(line, '\n') + line, text = line[:i], line[i+1:] if i := strings.IndexByte(line, ' '); i != -1 { d = append(d, directive{ kind: line[:i], @@ -452,6 +454,54 @@ func (d directive) isValid() bool { return d.kind != "" } +// args splits the directive line into tokens. +func (d directive) args() []string { + var args []string + start := -1 + state := 0 // 0 = boundary, 1 = in token, 2 = in quote, 3 = quote backslash + for i, r := range d.line { + switch state { + case 0: + // Argument boundary + switch { + case r == '"': + start = i + state = 2 + case !unicode.IsSpace(r): + start = i + state = 1 + } + case 1: + // In token + switch { + case unicode.IsSpace(r): + args = append(args, d.line[start:i]) + start = -1 + state = 0 + case r == '"': + state = 2 + } + case 2: + // In quotes + switch { + case r == '"': + state = 1 + case r == '\\': + state = 3 + } + case 3: + // Quote backslash. Consumes one character and jumps back into "in quote" state. + state = 2 + default: + panic("unreachable") + } + } + if start != -1 { + args = append(args, d.line[start:]) + } + return args +} + // isInjectFile reports whether a given file is an injection template. func isInjectFile(f *ast.File) bool { // TODO(light): better determination diff --git a/internal/goose/parse_test.go b/internal/goose/parse_test.go new file mode 100644 index 0000000..7ebef72 --- /dev/null +++ b/internal/goose/parse_test.go @@ -0,0 +1,37 @@ +package goose + +import ( + "testing" +) + +func TestDirectiveArgs(t *testing.T) { + tests := []struct { + line string + args []string + }{ + {"", []string{}}, + {" \t ", []string{}}, + {"foo", []string{"foo"}}, + {"foo bar", []string{"foo", "bar"}}, + {" foo \t bar ", []string{"foo", "bar"}}, + {"foo \"bar \t baz\" fido", []string{"foo", "\"bar \t baz\"", "fido"}}, + {"foo \"bar \t baz\".quux fido", []string{"foo", "\"bar \t baz\".quux", "fido"}}, + } + eq := func(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true + } + for _, test := range tests { + got := (directive{line: test.line}).args() + if !eq(got, test.args) { + t.Errorf("directive{line: %q}.args() = %q; want %q", test.line, got, test.args) + } + } +} diff --git a/internal/goose/testdata/MultiImport/foo/foo.go b/internal/goose/testdata/MultiImport/foo/foo.go new file mode 100644 index 0000000..1721d3a --- /dev/null +++ b/internal/goose/testdata/MultiImport/foo/foo.go @@ -0,0 +1,22 @@ +package main + +import "fmt" + +func main() { + fmt.Println(injectFooBar()) +} + +type Foo int +type FooBar int + +//goose:provide Foo +func provideFoo() Foo { + return 41 +} + +//goose:provide FooBar +func provideFooBar(foo Foo) FooBar { + return FooBar(foo) + 1 +} + +//goose:import Set Foo FooBar diff --git a/internal/goose/testdata/MultiImport/foo/goose.go b/internal/goose/testdata/MultiImport/foo/goose.go new file mode 100644 index 0000000..73f5093 --- /dev/null +++ b/internal/goose/testdata/MultiImport/foo/goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use Set + +func injectFooBar() FooBar diff --git a/internal/goose/testdata/MultiImport/out.txt b/internal/goose/testdata/MultiImport/out.txt new file mode 100644 index 0000000..d81cc07 --- /dev/null +++ b/internal/goose/testdata/MultiImport/out.txt @@ -0,0 +1 @@ +42 diff --git a/internal/goose/testdata/MultiImport/pkg b/internal/goose/testdata/MultiImport/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/MultiImport/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/MultiUse/foo/foo.go b/internal/goose/testdata/MultiUse/foo/foo.go new file mode 100644 index 0000000..232c06a --- /dev/null +++ b/internal/goose/testdata/MultiUse/foo/foo.go @@ -0,0 +1,20 @@ +package main + +import "fmt" + +func main() { + fmt.Println(injectFooBar()) +} + +type Foo int +type FooBar int + +//goose:provide Foo +func provideFoo() Foo { + return 41 +} + +//goose:provide FooBar +func provideFooBar(foo Foo) FooBar { + return FooBar(foo) + 1 +} diff --git a/internal/goose/testdata/MultiUse/foo/goose.go b/internal/goose/testdata/MultiUse/foo/goose.go new file mode 100644 index 0000000..d273f88 --- /dev/null +++ b/internal/goose/testdata/MultiUse/foo/goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use Foo FooBar + +func injectFooBar() FooBar diff --git a/internal/goose/testdata/MultiUse/out.txt b/internal/goose/testdata/MultiUse/out.txt new file mode 100644 index 0000000..d81cc07 --- /dev/null +++ b/internal/goose/testdata/MultiUse/out.txt @@ -0,0 +1 @@ +42 diff --git a/internal/goose/testdata/MultiUse/pkg b/internal/goose/testdata/MultiUse/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/MultiUse/pkg @@ -0,0 +1 @@ +foo