From fe016541749767cf17789c9079961539e1ba5907 Mon Sep 17 00:00:00 2001 From: Robert van Gent Date: Thu, 16 May 2019 09:56:42 -0700 Subject: [PATCH] cmd/wire: add a --header_file flag to the "gen" and "diff" commands (#175) --- cmd/wire/main.go | 49 ++++++++++++++++--- internal/check_api_change.sh | 2 +- internal/wire/testdata/Header/foo/foo.go | 29 +++++++++++ internal/wire/testdata/Header/foo/wire.go | 26 ++++++++++ internal/wire/testdata/Header/header | 2 + internal/wire/testdata/Header/pkg | 1 + .../wire/testdata/Header/want/program_out.txt | 1 + .../wire/testdata/Header/want/wire_gen.go | 15 ++++++ internal/wire/wire.go | 14 +++++- internal/wire/wire_test.go | 5 +- 10 files changed, 133 insertions(+), 11 deletions(-) create mode 100644 internal/wire/testdata/Header/foo/foo.go create mode 100644 internal/wire/testdata/Header/foo/wire.go create mode 100644 internal/wire/testdata/Header/header create mode 100644 internal/wire/testdata/Header/pkg create mode 100644 internal/wire/testdata/Header/want/program_out.txt create mode 100644 internal/wire/testdata/Header/want/wire_gen.go diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 2256155..202be20 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -83,7 +83,23 @@ func packages(f *flag.FlagSet) []string { return pkgs } -type genCmd struct{} +// newGenerateOptions returns an initialized wire.GenerateOptions, possibly +// with the Header option set. +func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { + opts := new(wire.GenerateOptions) + if headerFile != "" { + var err error + opts.Header, err = ioutil.ReadFile(headerFile) + if err != nil { + return nil, fmt.Errorf("failed to read header file %q: %v", headerFile, err) + } + } + return opts, nil +} + +type genCmd struct { + headerFile string +} func (*genCmd) Name() string { return "gen" } func (*genCmd) Synopsis() string { @@ -97,14 +113,22 @@ func (*genCmd) Usage() string { If no packages are listed, it defaults to ".". ` } -func (*genCmd) SetFlags(_ *flag.FlagSet) {} -func (*genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { +func (cmd *genCmd) SetFlags(f *flag.FlagSet) { + f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") +} + +func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { wd, err := os.Getwd() if err != nil { log.Println("failed to get working directory: ", err) return subcommands.ExitFailure } - outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f)) + opts, err := newGenerateOptions(cmd.headerFile) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) if len(errs) > 0 { logErrors(errs) log.Println("generate failed") @@ -138,7 +162,9 @@ func (*genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{} return subcommands.ExitSuccess } -type diffCmd struct{} +type diffCmd struct { + headerFile string +} func (*diffCmd) Name() string { return "diff" } func (*diffCmd) Synopsis() string { @@ -156,8 +182,10 @@ func (*diffCmd) Usage() string { plus an error if trouble. ` } -func (*diffCmd) SetFlags(_ *flag.FlagSet) {} -func (*diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { +func (cmd *diffCmd) SetFlags(f *flag.FlagSet) { + f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") +} +func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { const ( errReturn = subcommands.ExitStatus(2) diffReturn = subcommands.ExitStatus(1) @@ -167,7 +195,12 @@ func (*diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{ log.Println("failed to get working directory: ", err) return errReturn } - outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f)) + opts, err := newGenerateOptions(cmd.headerFile) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) if len(errs) > 0 { logErrors(errs) log.Println("generate failed") diff --git a/internal/check_api_change.sh b/internal/check_api_change.sh index c435b29..fdbf988 100755 --- a/internal/check_api_change.sh +++ b/internal/check_api_change.sh @@ -52,7 +52,7 @@ trap cleanup EXIT git clone -b "$UPSTREAM_BRANCH" . "$MASTER_CLONE_DIR" &> /dev/null incompatible_change_pkgs=() -PKGS=$(cd "$MASTER_CLONE_DIR"; go list ./... | grep -v test) +PKGS=$(cd "$MASTER_CLONE_DIR"; go list ./... | grep -v test | grep -v internal) for pkg in $PKGS; do echo " Testing ${pkg}..." diff --git a/internal/wire/testdata/Header/foo/foo.go b/internal/wire/testdata/Header/foo/foo.go new file mode 100644 index 0000000..67bc8a3 --- /dev/null +++ b/internal/wire/testdata/Header/foo/foo.go @@ -0,0 +1,29 @@ +// Copyright 2019 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" +) + +func main() { + fmt.Println(injectFoo()) +} + +type Foo int + +func provideFoo() Foo { + return 41 +} diff --git a/internal/wire/testdata/Header/foo/wire.go b/internal/wire/testdata/Header/foo/wire.go new file mode 100644 index 0000000..957f177 --- /dev/null +++ b/internal/wire/testdata/Header/foo/wire.go @@ -0,0 +1,26 @@ +// Copyright 2019 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func injectFoo() Foo { + wire.Build(provideFoo) + return Foo(0) +} diff --git a/internal/wire/testdata/Header/header b/internal/wire/testdata/Header/header new file mode 100644 index 0000000..5ad7152 --- /dev/null +++ b/internal/wire/testdata/Header/header @@ -0,0 +1,2 @@ +// This is a sample header file. +// diff --git a/internal/wire/testdata/Header/pkg b/internal/wire/testdata/Header/pkg new file mode 100644 index 0000000..f7a5c8c --- /dev/null +++ b/internal/wire/testdata/Header/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/Header/want/program_out.txt b/internal/wire/testdata/Header/want/program_out.txt new file mode 100644 index 0000000..87523dd --- /dev/null +++ b/internal/wire/testdata/Header/want/program_out.txt @@ -0,0 +1 @@ +41 diff --git a/internal/wire/testdata/Header/want/wire_gen.go b/internal/wire/testdata/Header/want/wire_gen.go new file mode 100644 index 0000000..5acba8f --- /dev/null +++ b/internal/wire/testdata/Header/want/wire_gen.go @@ -0,0 +1,15 @@ +// This is a sample header file. +// +// Code generated by Wire. DO NOT EDIT. + +//go:generate wire +//+build !wireinject + +package main + +// Injectors from wire.go: + +func injectFoo() Foo { + foo := provideFoo() + return foo +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 4aabf10..b585c8d 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -60,6 +60,12 @@ func (gen GenerateResult) Commit() error { return ioutil.WriteFile(gen.OutputPath, gen.Content, 0666) } +// GenerateOptions holds options for Generate. +type GenerateOptions struct { + // Header will be inserted at the start of each generated file. + Header []byte +} + // Generate performs dependency injection for the packages that match the given // patterns, return a GenerateResult for each package. The package pattern is // defined by the underlying build system. For the go tool, this is described at @@ -72,7 +78,10 @@ func (gen GenerateResult) Commit() error { // takes precedence. // // Generate may return one or more errors if it failed to load the packages. -func Generate(ctx context.Context, wd string, env []string, patterns []string) ([]GenerateResult, []error) { +func Generate(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, []error) { + if opts == nil { + opts = &GenerateOptions{} + } pkgs, errs := load(ctx, wd, env, patterns) if len(errs) > 0 { return nil, errs @@ -94,6 +103,9 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string) ( } copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) goSrc := g.frame() + if len(opts.Header) > 0 { + goSrc = append(opts.Header, goSrc...) + } fmtSrc, err := format.Source(goSrc) if err != nil { // This is likely a bug from a poorly generated source file. diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 82dce4c..4717c74 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -90,7 +90,7 @@ func TestWire(t *testing.T) { t.Fatal(err) } wd := filepath.Join(gopath, "src", "example.com") - gens, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), []string{test.pkg}) + gens, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), []string{test.pkg}, &GenerateOptions{Header: test.header}) var gen GenerateResult if len(gens) > 1 { t.Fatalf("got %d generated files, want 0 or 1", len(gens)) @@ -428,6 +428,7 @@ func scrubLineColumn(s string) (replacement string, n int) { type testCase struct { name string pkg string + header []byte goFiles map[string][]byte wantProgramOutput []byte wantWireOutput []byte @@ -471,6 +472,7 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) { if err != nil { return nil, fmt.Errorf("load test case %s: %v", name, err) } + header, _ := ioutil.ReadFile(filepath.Join(root, "header")) var wantProgramOutput []byte var wantWireOutput []byte wireErrb, err := ioutil.ReadFile(filepath.Join(root, "want", "wire_errs.txt")) @@ -521,6 +523,7 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) { return &testCase{ name: name, pkg: string(bytes.TrimSpace(pkg)), + header: header, goFiles: goFiles, wantWireOutput: wantWireOutput, wantProgramOutput: wantProgramOutput,