diff --git a/cmd/wire/main.go b/cmd/wire/main.go index f3b153e..fa37e51 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -100,6 +100,7 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { type genCmd struct { headerFile string prefixFileName string + tags string } func (*genCmd) Name() string { return "gen" } @@ -117,6 +118,7 @@ func (*genCmd) Usage() string { func (cmd *genCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") + f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") } func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { @@ -132,6 +134,7 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa } opts.PrefixOutputFile = cmd.prefixFileName + opts.Tags = cmd.tags outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) if len(errs) > 0 { @@ -169,6 +172,7 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa type diffCmd struct { headerFile string + tags string } func (*diffCmd) Name() string { return "diff" } @@ -189,6 +193,7 @@ func (*diffCmd) Usage() string { } func (cmd *diffCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") + f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") } func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { const ( @@ -206,6 +211,8 @@ func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf return subcommands.ExitFailure } + opts.Tags = cmd.tags + outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) if len(errs) > 0 { logErrors(errs) @@ -253,7 +260,9 @@ func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf return subcommands.ExitSuccess } -type showCmd struct{} +type showCmd struct { + tags string +} func (*showCmd) Name() string { return "show" } func (*showCmd) Synopsis() string { @@ -270,14 +279,16 @@ func (*showCmd) Usage() string { If no packages are listed, it defaults to ".". ` } -func (*showCmd) SetFlags(_ *flag.FlagSet) {} -func (*showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { +func (cmd *showCmd) SetFlags(f *flag.FlagSet) { + f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") +} +func (cmd *showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { wd, err := os.Getwd() if err != nil { log.Println("failed to get working directory: ", err) return subcommands.ExitFailure } - info, errs := wire.Load(ctx, wd, os.Environ(), packages(f)) + info, errs := wire.Load(ctx, wd, os.Environ(), cmd.tags, packages(f)) if info != nil { keys := make([]wire.ProviderSetID, 0, len(info.Sets)) for k := range info.Sets { @@ -341,14 +352,16 @@ func (*showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{ return subcommands.ExitSuccess } -type checkCmd struct{} +type checkCmd struct { + tags string +} func (*checkCmd) Name() string { return "check" } func (*checkCmd) Synopsis() string { return "print any Wire errors found" } func (*checkCmd) Usage() string { - return `check [packages] + return `check [-tags tag,list] [packages] Given one or more packages, check prints any type-checking or Wire errors found with top-level variable provider sets or injector functions. @@ -356,14 +369,16 @@ func (*checkCmd) Usage() string { If no packages are listed, it defaults to ".". ` } -func (*checkCmd) SetFlags(_ *flag.FlagSet) {} -func (*checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { +func (cmd *checkCmd) SetFlags(f *flag.FlagSet) { + f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") +} +func (cmd *checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { wd, err := os.Getwd() if err != nil { log.Println("failed to get working directory: ", err) return subcommands.ExitFailure } - _, errs := wire.Load(ctx, wd, os.Environ(), packages(f)) + _, errs := wire.Load(ctx, wd, os.Environ(), cmd.tags, packages(f)) if len(errs) > 0 { logErrors(errs) log.Println("error loading packages") diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 7a2e1a8..93fbda8 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -248,8 +248,8 @@ type Field struct { // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func Load(ctx context.Context, wd string, env []string, patterns []string) (*Info, []error) { - pkgs, errs := load(ctx, wd, env, patterns) +func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { + pkgs, errs := load(ctx, wd, env, tags, patterns) if len(errs) > 0 { return nil, errs } @@ -349,7 +349,7 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func load(ctx context.Context, wd string, env []string, patterns []string) ([]*packages.Package, []error) { +func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, []error) { cfg := &packages.Config{ Context: ctx, Mode: packages.LoadAllSyntax, @@ -358,6 +358,9 @@ func load(ctx context.Context, wd string, env []string, patterns []string) ([]*p BuildFlags: []string{"-tags=wireinject"}, // TODO(light): Use ParseFile to skip function bodies and comments in indirect packages. } + if len(tags) > 0 { + cfg.BuildFlags[0] += " " + tags + } escaped := make([]string, len(patterns)) for i := range patterns { escaped[i] = "pattern=" + patterns[i] diff --git a/internal/wire/wire.go b/internal/wire/wire.go index a75144f..d23e6b6 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -65,6 +65,7 @@ type GenerateOptions struct { // Header will be inserted at the start of each generated file. Header []byte PrefixOutputFile string + Tags string } // Generate performs dependency injection for the packages that match the given @@ -83,7 +84,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } - pkgs, errs := load(ctx, wd, env, patterns) + pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) if len(errs) > 0 { return nil, errs } @@ -103,7 +104,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o continue } copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) - goSrc := g.frame() + goSrc := g.frame(opts.Tags) if len(opts.Header) > 0 { goSrc = append(opts.Header, goSrc...) } @@ -257,13 +258,16 @@ func newGen(pkg *packages.Package) *gen { } // frame bakes the built up source body into an unformatted Go source file. -func (g *gen) frame() []byte { +func (g *gen) frame(tags string) []byte { if g.buf.Len() == 0 { return nil } var buf bytes.Buffer + if len(tags) > 0 { + tags = fmt.Sprintf(" gen -tags \"%s\"", tags) + } buf.WriteString("// Code generated by Wire. DO NOT EDIT.\n\n") - buf.WriteString("//go:generate wire\n") + buf.WriteString("//go:generate wire" + tags + "\n") buf.WriteString("//+build !wireinject\n\n") buf.WriteString("package ") buf.WriteString(g.pkg.Name)