wire/internal/wire: use on-disk GOPATH in generate tests (google/go-cloud#616)

The primary motivation is to permit a move to using go/packages instead
of go/loader. go/packages runs exclusively by shelling out to the go
tool, which precludes use of the in-memory "magic" GOPATH being used
up to this point.

This has a secondary effect of removing a lot of code to support "magic"
GOPATH from the test infrastructure. This is on the whole good, but
necessitated a change in the error scrubbing: since the filenames are
no longer fixed, error scrubbing also must remove the leading
$GOPATH/src lines.

Another related change: since all callers of Generate needed to know the
package path in order to write out wire_gen.go (necessitating a
find-only import search) and Generate already has this information,
Generate now returns this information to the caller. This should further
reduce callers' coupling to Wire's load internals. It also eliminates
code duplication.

This should hopefully shake out any difference in path separators for
running on Windows, but I have not tested that yet.

Updates google/go-cloud#78
Updates google/go-cloud#323
This commit is contained in:
Ross Light
2018-11-06 08:44:51 -08:00
parent ab113bf8d1
commit 64470a2452
24 changed files with 301 additions and 367 deletions

View File

@@ -18,14 +18,12 @@
package main package main
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"go/build"
"go/token" "go/token"
"go/types" "go/types"
"io/ioutil"
"os" "os"
"path/filepath"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@@ -74,25 +72,17 @@ func generate(pkg string) error {
if err != nil { if err != nil {
return err return err
} }
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) out, errs := wire.Generate(context.Background(), wd, os.Environ(), pkg)
if err != nil {
return err
}
out, errs := wire.Generate(&build.Default, wd, pkg)
if len(errs) > 0 { if len(errs) > 0 {
logErrors(errs) logErrors(errs)
return errors.New("generate failed") return errors.New("generate failed")
} }
if len(out) == 0 { if len(out.Content) == 0 {
// No Wire directives, don't write anything. // No Wire directives, don't write anything.
fmt.Fprintln(os.Stderr, "wire: no injector found for", pkg) fmt.Fprintln(os.Stderr, "wire: no injector found for", pkg)
return nil return nil
} }
p := filepath.Join(pkgInfo.Dir, "wire_gen.go") return out.Commit()
if err := ioutil.WriteFile(p, out, 0666); err != nil {
return err
}
return nil
} }
// show runs the show subcommand. // show runs the show subcommand.
@@ -106,7 +96,7 @@ func show(pkgs ...string) error {
if err != nil { if err != nil {
return err return err
} }
info, errs := wire.Load(&build.Default, wd, pkgs) info, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs)
if info != nil { if info != nil {
keys := make([]wire.ProviderSetID, 0, len(info.Sets)) keys := make([]wire.ProviderSetID, 0, len(info.Sets))
for k := range info.Sets { for k := range info.Sets {
@@ -185,7 +175,7 @@ func check(pkgs ...string) error {
if err != nil { if err != nil {
return err return err
} }
_, errs := wire.Load(&build.Default, wd, pkgs) _, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs)
if len(errs) > 0 { if len(errs) > 0 {
logErrors(errs) logErrors(errs)
return errors.New("error loading packages") return errors.New("error loading packages")

View File

@@ -15,6 +15,7 @@
package wire package wire
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
@@ -190,11 +191,19 @@ type Value struct {
info *types.Info info *types.Info
} }
// Load finds all the provider sets in the given packages, as well as // Load finds all the provider sets in the packages that match the given
// the provider sets' transitive dependencies. It may return both errors // patterns, as well as the provider sets' transitive dependencies. It
// and Info. // may return both errors and Info. The patterns are defined by the
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { // underlying build system. For the go tool, this is described at
prog, errs := load(bctx, wd, pkgs) // https://golang.org/cmd/go/#hdr-Package_lists_and_patterns
//
// wd is the working directory and env is the set of environment
// variables to use when loading the packages specified by patterns. If
// env is nil or empty, it is interpreted as an empty set of variables.
// In case of duplicate environment variables, the last one in the list
// takes precedence.
func Load(ctx context.Context, wd string, env []string, patterns []string) (*Info, []error) {
prog, errs := load(ctx, wd, env, patterns)
if len(errs) > 0 { if len(errs) > 0 {
return nil, errs return nil, errs
} }
@@ -275,12 +284,22 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
return info, ec.errors return info, ec.errors
} }
// load typechecks the packages, including function body type checking // load typechecks the packages that match the given patterns, including
// for the packages directly named. // function body type checking for the packages that directly match. The
func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []error) { // patterns are defined by the underlying build system. For the go tool,
// this is described at
// https://golang.org/cmd/go/#hdr-Package_lists_and_patterns
//
// wd is the working directory and env is the set of environment
// variables to use when loading the packages specified by patterns. If
// env is nil or empty, it is interpreted as an empty set of variables.
// In case of duplicate environment variables, the last one in the list
// takes precedence.
func load(ctx context.Context, wd string, env []string, patterns []string) (*loader.Program, []error) {
bctx := buildContextFromEnv(env)
var foundPkgs []*build.Package var foundPkgs []*build.Package
ec := new(errorCollector) ec := new(errorCollector)
for _, name := range pkgs { for _, name := range patterns {
p, err := bctx.Import(name, wd, build.FindOnly) p, err := bctx.Import(name, wd, build.FindOnly)
if err != nil { if err != nil {
ec.add(err) ec.add(err)
@@ -320,7 +339,7 @@ func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []err
return pkg, err return pkg, err
}, },
} }
for _, name := range pkgs { for _, name := range patterns {
conf.Import(name) conf.Import(name)
} }
@@ -334,6 +353,35 @@ func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []err
return prog, nil return prog, nil
} }
func buildContextFromEnv(env []string) *build.Context {
// TODO(#78): Remove this function in favor of using go/packages,
// which does not need a *build.Context.
getenv := func(name string) string {
for i := len(env) - 1; i >= 0; i-- {
if strings.HasPrefix(env[i], name+"=") {
return env[i][len(name)+1:]
}
}
return ""
}
bctx := new(build.Context)
*bctx = build.Default
if v := getenv("GOARCH"); v != "" {
bctx.GOARCH = v
}
if v := getenv("GOOS"); v != "" {
bctx.GOOS = v
}
if v := getenv("GOROOT"); v != "" {
bctx.GOROOT = v
}
if v := getenv("GOPATH"); v != "" {
bctx.GOPATH = v
}
return bctx
}
func importPathInPkgList(pkgs []*build.Package, path string) bool { func importPathInPkgList(pkgs []*build.Package, path string) bool {
for _, p := range pkgs { for _, p := range pkgs {
if path == p.ImportPath { if path == p.ImportPath {

View File

@@ -1 +1 @@
./example.com/foo ./foo

View File

@@ -1,4 +1,4 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: cycle for example.com/foo.Bar: example.com/foo/wire.go:x:y: cycle for example.com/foo.Bar:
example.com/foo.Bar (example.com/foo.provideBar) -> example.com/foo.Bar (example.com/foo.provideBar) ->
example.com/foo.Foo (example.com/foo.provideFoo) -> example.com/foo.Foo (example.com/foo.provideFoo) ->
example.com/foo.Baz (example.com/foo.provideBaz) -> example.com/foo.Baz (example.com/foo.provideBaz) ->

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: var example.com/foo.myFakeSet struct{} is not a provider or a provider set example.com/foo/wire.go:x:y: var example.com/foo.myFakeSet struct{} is not a provider or a provider set

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at /wire_gopath/src/example.com/foo/foo.go:x:y example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at example.com/foo/foo.go:x:y

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns cleanup but injection does not return cleanup function example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns cleanup but injection does not return cleanup function

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns error but injection not allowed to fail example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns error but injection not allowed to fail

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: string does not implement example.com/foo.Fooer example.com/foo/wire.go:x:y: string does not implement example.com/foo.Fooer

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: first argument to Bind must be a pointer to an interface type; found string example.com/foo/wire.go:x:y: first argument to Bind must be a pointer to an interface type; found string

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: too few arguments in call to wire.Bind example.com/foo/wire.go:x:y: too few arguments in call to wire.Bind

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: string does not implement io.Reader example.com/foo/wire.go:x:y: string does not implement io.Reader

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: first argument to InterfaceValue must be a pointer to an interface type; found string example.com/foo/wire.go:x:y: first argument to InterfaceValue must be a pointer to an interface type; found string

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: too few arguments in call to wire.InterfaceValue example.com/foo/wire.go:x:y: too few arguments in call to wire.InterfaceValue

View File

@@ -1,41 +1,41 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo
current: current:
<- provider "provideFooAgain" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFooAgain" (example.com/foo/foo.go:x:y)
previous: previous:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo
current: current:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
previous: previous:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider set "Set" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo
current: current:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
previous: previous:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider set "Set" (example.com/foo/foo.go:x:y)
<- provider set "SuperSet" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider set "SuperSet" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/foo.go:x:y: SetWithDuplicateBindings has multiple bindings for example.com/foo.Foo example.com/foo/foo.go:x:y: SetWithDuplicateBindings has multiple bindings for example.com/foo.Foo
current: current:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider set "Set" (example.com/foo/foo.go:x:y)
<- provider set "SuperSet" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider set "SuperSet" (example.com/foo/foo.go:x:y)
previous: previous:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider set "Set" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo
current: current:
<- wire.Value (/wire_gopath/src/example.com/foo/wire.go:x:y) <- wire.Value (example.com/foo/wire.go:x:y)
previous: previous:
<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideFoo" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Bar example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Bar
current: current:
<- wire.Bind (/wire_gopath/src/example.com/foo/wire.go:x:y) <- wire.Bind (example.com/foo/wire.go:x:y)
previous: previous:
<- provider "provideBar" (/wire_gopath/src/example.com/foo/foo.go:x:y) <- provider "provideBar" (example.com/foo/foo.go:x:y)

View File

@@ -1,12 +1,12 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMissingOutputType: no provider found for example.com/foo.Foo, output of injector example.com/foo/wire.go:x:y: inject injectMissingOutputType: no provider found for example.com/foo.Foo, output of injector
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Foo example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Foo
needed by example.com/foo.Baz in provider "provideBaz" (/wire_gopath/src/example.com/foo/foo.go:x:y) needed by example.com/foo.Baz in provider "provideBaz" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Bar example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Bar
needed by example.com/foo.Baz in provider "provideBaz" (/wire_gopath/src/example.com/foo/foo.go:x:y) needed by example.com/foo.Baz in provider "provideBaz" (example.com/foo/foo.go:x:y)
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMissingRecursiveType: no provider found for example.com/foo.Foo example.com/foo/wire.go:x:y: inject injectMissingRecursiveType: no provider found for example.com/foo.Foo
needed by example.com/foo.Zip in provider "provideZip" (/wire_gopath/src/example.com/foo/foo.go:x:y) needed by example.com/foo.Zip in provider "provideZip" (example.com/foo/foo.go:x:y)
needed by example.com/foo.Zap in provider "provideZap" (/wire_gopath/src/example.com/foo/foo.go:x:y) needed by example.com/foo.Zap in provider "provideZap" (example.com/foo/foo.go:x:y)
needed by example.com/foo.Zop in provider "provideZop" (/wire_gopath/src/example.com/foo/foo.go:x:y) needed by example.com/foo.Zop in provider "provideZop" (example.com/foo/foo.go:x:y)

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectFooer: no provider found for example.com/foo.Fooer, output of injector example.com/foo/wire.go:x:y: inject injectFooer: no provider found for example.com/foo.Fooer, output of injector

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: foo not exported by package bar example.com/foo/wire.go:x:y: foo not exported by package bar

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectedMessage: value string can't be used: uses unexported identifier privateMsg example.com/foo/wire.go:x:y: inject injectedMessage: value string can't be used: uses unexported identifier privateMsg

View File

@@ -1,7 +1,7 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused provider set "unusedSet" example.com/foo/wire.go:x:y: inject injectBar: unused provider set "unusedSet"
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused provider "provideUnused" example.com/foo/wire.go:x:y: inject injectBar: unused provider "provideUnused"
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused value of type string example.com/foo/wire.go:x:y: inject injectBar: unused value of type string
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused interface binding to type example.com/foo.Fooer example.com/foo/wire.go:x:y: inject injectBar: unused interface binding to type example.com/foo.Fooer

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: value int can't be used: f is not declared in package scope example.com/foo/wire.go:x:y: inject injectBar: value int can't be used: f is not declared in package scope

View File

@@ -1 +1 @@
/wire_gopath/src/example.com/foo/wire.go:x:y: argument to Value may not be an interface value (found io.Reader); use InterfaceValue instead example.com/foo/wire.go:x:y: argument to Value may not be an interface value (found io.Reader); use InterfaceValue instead

View File

@@ -18,13 +18,15 @@ package wire
import ( import (
"bytes" "bytes"
"context"
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build"
"go/format" "go/format"
"go/printer" "go/printer"
"go/token" "go/token"
"go/types" "go/types"
"io/ioutil"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv" "strconv"
@@ -36,22 +38,50 @@ import (
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
) )
// GeneratedFile stores the content of a call to Generate and the
// desired on-disk location of the file.
type GeneratedFile struct {
Path string
Content []byte
}
// Commit writes the generated file to disk.
func (gen GeneratedFile) Commit() error {
if len(gen.Content) == 0 {
return nil
}
return ioutil.WriteFile(gen.Path, gen.Content, 0666)
}
// Generate performs dependency injection for a single package, // Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code. // returning the gofmt'd Go source code. The package pattern is defined
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { // by the underlying build system. For the go tool, this is described at
prog, errs := load(bctx, wd, []string{pkg}) // https://golang.org/cmd/go/#hdr-Package_lists_and_patterns
//
// wd is the working directory and env is the set of environment
// variables to use when loading the package specified by pkgPattern. If
// env is nil or empty, it is interpreted as an empty set of variables.
// In case of duplicate environment variables, the last one in the list
// takes precedence.
func Generate(ctx context.Context, wd string, env []string, pkgPattern string) (GeneratedFile, []error) {
prog, errs := load(ctx, wd, env, []string{pkgPattern})
if len(errs) > 0 { if len(errs) > 0 {
return nil, errs return GeneratedFile{}, errs
} }
if len(prog.InitialPackages()) != 1 { if len(prog.InitialPackages()) != 1 {
// This is more of a violated precondition than anything else. // This is more of a violated precondition than anything else.
return nil, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))} return GeneratedFile{}, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))}
} }
pkgInfo := prog.InitialPackages()[0] pkgInfo := prog.InitialPackages()[0]
outDir, err := detectOutputDir(prog.Fset, pkgInfo.Files)
if err != nil {
return GeneratedFile{}, []error{fmt.Errorf("load: %v", err)}
}
outFname := filepath.Join(outDir, "wire_gen.go")
g := newGen(prog, pkgInfo.Pkg.Path()) g := newGen(prog, pkgInfo.Pkg.Path())
injectorFiles, errs := generateInjectors(g, pkgInfo) injectorFiles, errs := generateInjectors(g, pkgInfo)
if len(errs) > 0 { if len(errs) > 0 {
return nil, errs return GeneratedFile{}, errs
} }
copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info) copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info)
goSrc := g.frame() goSrc := g.frame()
@@ -59,9 +89,22 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) {
if err != nil { if err != nil {
// This is likely a bug from a poorly generated source file. // This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source. // Return an error and the unformatted source.
return goSrc, []error{err} return GeneratedFile{Path: outFname, Content: goSrc}, []error{err}
} }
return fmtSrc, nil return GeneratedFile{Path: outFname, Content: fmtSrc}, nil
}
func detectOutputDir(fset *token.FileSet, files []*ast.File) (string, error) {
if len(files) == 0 {
return "", errors.New("no files to derive output directory from")
}
dir := filepath.Dir(fset.File(files[0].Package).Name())
for _, f := range files[1:] {
if dir2 := filepath.Dir(fset.File(f.Package).Name()); dir2 != dir {
return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2)
}
}
return dir, nil
} }
// generateInjectors generates the injectors for a given package. // generateInjectors generates the injectors for a given package.

View File

@@ -16,21 +16,16 @@ package wire
import ( import (
"bytes" "bytes"
"errors" "context"
"fmt" "fmt"
"go/build" "go/build"
"go/types" "go/types"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"regexp"
"runtime"
"sort"
"strings" "strings"
"testing" "testing"
"time"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
@@ -63,29 +58,39 @@ func TestWire(t *testing.T) {
} }
tests = append(tests, test) tests = append(tests, test)
} }
wd := filepath.Join(magicGOPATH(), "src")
var goToolPath string
if *setup.Record { if *setup.Record {
if _, err := os.Stat(filepath.Join(build.Default.GOROOT, "bin", "go")); err != nil { goToolPath = filepath.Join(build.Default.GOROOT, "bin", "go")
if _, err := os.Stat(goToolPath); err != nil {
t.Fatal("go toolchain not available:", err) t.Fatal("go toolchain not available:", err)
} }
} }
ctx := context.Background()
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
// Run Wire from a fake build context. // Materialize a temporary GOPATH directory.
bctx := test.buildContext() gopath, err := ioutil.TempDir("", "wire_test")
gen, errs := Generate(bctx, wd, test.pkg) if err != nil {
if len(gen) > 0 { t.Fatal(err)
defer t.Logf("wire_gen.go:\n%s", gen) }
defer os.RemoveAll(gopath)
if err := test.materialize(gopath); err != nil {
t.Fatal(err)
}
wd := filepath.Join(gopath, "src", "example.com")
gen, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), test.pkg)
if len(gen.Content) > 0 {
defer t.Logf("wire_gen.go:\n%s", gen.Content)
} }
if len(errs) > 0 { if len(errs) > 0 {
gotErrStrings := make([]string, len(errs)) gotErrStrings := make([]string, len(errs))
for i, e := range errs { for i, e := range errs {
gotErrStrings[i] = scrubError(e.Error()) t.Log(e.Error())
t.Log(gotErrStrings[i]) gotErrStrings[i] = scrubError(gopath, e.Error())
} }
if !test.wantWireError { if !test.wantWireError {
t.Fatal("Did not expect errors. To -record an error, create want/wire_errs.txt.") t.Fatal("Did not expect errors. To -record an error, create want/wire_errs.txt.")
@@ -105,26 +110,37 @@ func TestWire(t *testing.T) {
if test.wantWireError { if test.wantWireError {
t.Fatal("wire succeeded; want error") t.Fatal("wire succeeded; want error")
} }
outPathSane := true
if prefix := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator); !strings.HasPrefix(gen.Path, prefix) {
outPathSane = false
t.Errorf("suggested output path = %q; want to start with %q", gen.Path, prefix)
}
if *setup.Record { if *setup.Record {
// Record ==> Build the generated Wire code, // Record ==> Build the generated Wire code,
// check that the program's output matches the // check that the program's output matches the
// expected output, save wire output on // expected output, save wire output on
// success. // success.
if err := goBuildCheck(test, wd, bctx, gen); err != nil { if !outPathSane {
return
}
if err := gen.Commit(); err != nil {
t.Fatalf("failed to write wire_gen.go to test GOPATH: %v", err)
}
if err := goBuildCheck(goToolPath, gopath, test); err != nil {
t.Fatalf("go build check failed: %v", err) t.Fatalf("go build check failed: %v", err)
} }
wireGenFile := filepath.Join(testRoot, test.name, "want", "wire_gen.go") testdataWireGenPath := filepath.Join(testRoot, test.name, "want", "wire_gen.go")
if err := ioutil.WriteFile(wireGenFile, gen, 0666); err != nil { if err := ioutil.WriteFile(testdataWireGenPath, gen.Content, 0666); err != nil {
t.Fatalf("failed to write wire_gen.go file: %v", err) t.Fatalf("failed to record wire_gen.go to testdata: %v", err)
} }
} else { } else {
// Replay ==> Load golden file and compare to // Replay ==> Load golden file and compare to
// generated result. This check is meant to // generated result. This check is meant to
// detect non-deterministic behavior in the // detect non-deterministic behavior in the
// Generate function. // Generate function.
if !bytes.Equal(gen, test.wantWireOutput) { if !bytes.Equal(gen.Content, test.wantWireOutput) {
gotS, wantS := string(gen), string(test.wantWireOutput) gotS, wantS := string(gen.Content), string(test.wantWireOutput)
diff := cmp.Diff(strings.Split(gotS, "\n"), strings.Split(wantS, "\n")) diff := cmp.Diff(strings.Split(gotS, "\n"), strings.Split(wantS, "\n"))
t.Fatalf("wire output differs from golden file. If this change is expected, run with -record to update the wire_gen.go file.\n*** got:\n%s\n\n*** want:\n%s\n\n*** diff:\n%s", gotS, wantS, diff) t.Fatalf("wire output differs from golden file. If this change is expected, run with -record to update the wire_gen.go file.\n*** got:\n%s\n\n*** want:\n%s\n\n*** diff:\n%s", gotS, wantS, diff)
} }
@@ -133,49 +149,27 @@ func TestWire(t *testing.T) {
} }
} }
func goBuildCheck(test *testCase, wd string, bctx *build.Context, gen []byte) error { func goBuildCheck(goToolPath, gopath string, test *testCase) error {
// Find the absolute import path, since test.pkg may be a relative // Write go.mod files for example.com and the wire package.
// import path. // TODO(#78): Move this to happen in materialize() once modules work.
genPkg, err := bctx.Import(test.pkg, wd, build.FindOnly)
if err != nil {
return err
}
// Run a `go build` with the generated output.
gopath, err := ioutil.TempDir("", "wire_test")
if err != nil {
return err
}
defer os.RemoveAll(gopath)
if err := test.materialize(gopath); err != nil {
return err
}
if len(gen) > 0 {
genPath := filepath.Join(gopath, "src", filepath.FromSlash(genPkg.ImportPath), "wire_gen.go")
if err := ioutil.WriteFile(genPath, gen, 0666); err != nil {
return err
}
}
if err := writeGoMod(gopath); err != nil { if err := writeGoMod(gopath); err != nil {
return err return err
} }
// Run `go build`.
testExePath := filepath.Join(gopath, "bin", "testprog") testExePath := filepath.Join(gopath, "bin", "testprog")
realBuildCtx := &build.Context{
GOARCH: bctx.GOARCH,
GOOS: bctx.GOOS,
GOROOT: bctx.GOROOT,
GOPATH: gopath,
CgoEnabled: bctx.CgoEnabled,
Compiler: bctx.Compiler,
BuildTags: bctx.BuildTags,
ReleaseTags: bctx.ReleaseTags,
}
buildDir := filepath.Join(gopath, "src", genPkg.ImportPath)
buildCmd := []string{"build", "-o", testExePath} buildCmd := []string{"build", "-o", testExePath}
if test.name == "Vendor" && os.Getenv("GO111MODULE") == "on" { if test.name == "Vendor" && os.Getenv("GO111MODULE") == "on" {
buildCmd = append(buildCmd, "-mod=vendor") buildCmd = append(buildCmd, "-mod=vendor")
} }
if err := runGo(realBuildCtx, buildDir, buildCmd...); err != nil { buildCmd = append(buildCmd, test.pkg)
cmd := exec.Command(goToolPath, buildCmd...)
cmd.Dir = filepath.Join(gopath, "src", "example.com")
cmd.Env = append(os.Environ(), "GOPATH="+gopath)
if buildOut, err := cmd.CombinedOutput(); err != nil {
if len(buildOut) > 0 {
return fmt.Errorf("build: %v; output:\n%s", err, buildOut)
}
return fmt.Errorf("build: %v", err) return fmt.Errorf("build: %v", err)
} }
@@ -332,24 +326,15 @@ func TestDisambiguate(t *testing.T) {
func isIdent(s string) bool { func isIdent(s string) bool {
if len(s) == 0 { if len(s) == 0 {
if s == "foo" {
panic("BREAK3")
}
return false return false
} }
r, i := utf8.DecodeRuneInString(s) r, i := utf8.DecodeRuneInString(s)
if !unicode.IsLetter(r) && r != '_' { if !unicode.IsLetter(r) && r != '_' {
if s == "foo" {
panic("BREAK2")
}
return false return false
} }
for i < len(s) { for i < len(s) {
r, sz := utf8.DecodeRuneInString(s[i:]) r, sz := utf8.DecodeRuneInString(s[i:])
if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
if s == "foo" {
panic("BREAK1")
}
return false return false
} }
i += sz i += sz
@@ -357,6 +342,80 @@ func isIdent(s string) bool {
return true return true
} }
// scrubError rewrites the given string to remove occurrences of GOPATH/src,
// rewrites OS-specific path separators to slashes, and any line/column
// information to a fixed ":x:y". For example, if the gopath parameter is
// "C:\GOPATH" and running on Windows, the string
// "C:\GOPATH\src\foo\bar.go:15:4" would be rewritten to "foo/bar.go:x:y".
func scrubError(gopath string, s string) string {
sb := new(strings.Builder)
query := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator)
for {
// Find next occurrence of source root. This indicates the next path to
// scrub.
start := strings.Index(s, query)
if start == -1 {
sb.WriteString(s)
break
}
// Find end of file name (extension ".go").
fileStart := start + len(query)
fileEnd := strings.Index(s[fileStart:], ".go")
if fileEnd == -1 {
// If no ".go" occurs to end of string, further searches will fail too.
// Break the loop.
sb.WriteString(s)
break
}
fileEnd += fileStart + 3 // Advance to end of extension.
// Write out file name and advance scrub position.
file := s[fileStart:fileEnd]
if os.PathSeparator != '/' {
file = strings.Replace(file, string(os.PathSeparator), "/", -1)
}
sb.WriteString(s[:start])
sb.WriteString(file)
s = s[fileEnd:]
// Peek past to see if there is line/column info.
linecol, linecolLen := scrubLineColumn(s)
sb.WriteString(linecol)
s = s[linecolLen:]
}
return sb.String()
}
func scrubLineColumn(s string) (replacement string, n int) {
if !strings.HasPrefix(s, ":") {
return "", 0
}
// Skip first colon and run of digits.
for n++; len(s) > n && '0' <= s[n] && s[n] <= '9'; {
n++
}
if n == 1 {
// No digits followed colon.
return "", 0
}
// Start on column part.
if !strings.HasPrefix(s[n:], ":") {
return ":x", n
}
lineEnd := n
// Skip second colon and run of digits.
for n++; len(s) > n && '0' <= s[n] && s[n] <= '9'; {
n++
}
if n == lineEnd+1 {
// No digits followed second colon.
return ":x", lineEnd
}
return ":x:y", n
}
type testCase struct { type testCase struct {
name string name string
pkg string pkg string
@@ -367,14 +426,6 @@ type testCase struct {
wantWireErrorStrings []string wantWireErrorStrings []string
} }
var scrubLineNumberAndPositionRegex = regexp.MustCompile("\\.go:[\\d]+:[\\d]+")
var scrubLineNumberRegex = regexp.MustCompile("\\.go:[\\d]+")
func scrubError(s string) string {
s = scrubLineNumberAndPositionRegex.ReplaceAllString(s, ".go:x:y")
return scrubLineNumberRegex.ReplaceAllString(s, ".go:x")
}
// loadTestCase reads a test case from a directory. // loadTestCase reads a test case from a directory.
// //
// The directory structure is: // The directory structure is:
@@ -395,7 +446,7 @@ func scrubError(s string) string {
// missing if no errors expected. // missing if no errors expected.
// Distinct errors are separated by a blank line, // Distinct errors are separated by a blank line,
// and line numbers and line positions are scrubbed // and line numbers and line positions are scrubbed
// (e.g., "foo.go:52:8" --> "foo.go:x:y"). // (e.g. "$GOPATH/src/foo.go:52:8" --> "foo.go:x:y").
// //
// wire_gen.go // wire_gen.go
// verified output of wire from a test run with // verified output of wire from a test run with
@@ -417,7 +468,7 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) {
wantWireError := err == nil wantWireError := err == nil
var wantWireErrorStrings []string var wantWireErrorStrings []string
if wantWireError { if wantWireError {
wantWireErrorStrings = strings.Split(scrubError(string(wireErrb)), "\n\n") wantWireErrorStrings = strings.Split(string(wireErrb), "\n\n")
} else { } else {
if !*setup.Record { if !*setup.Record {
wantWireOutput, err = ioutil.ReadFile(filepath.Join(root, "want", "wire_gen.go")) wantWireOutput, err = ioutil.ReadFile(filepath.Join(root, "want", "wire_gen.go"))
@@ -448,7 +499,7 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) {
if err != nil { if err != nil {
return err return err
} }
goFiles[filepath.Join("example.com", rel)] = data goFiles["example.com/"+filepath.ToSlash(rel)] = data
return nil return nil
}) })
if err != nil { if err != nil {
@@ -465,187 +516,11 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) {
}, nil }, nil
} }
func (test *testCase) buildContext() *build.Context {
return &build.Context{
GOARCH: build.Default.GOARCH,
GOOS: build.Default.GOOS,
GOROOT: build.Default.GOROOT,
GOPATH: magicGOPATH(),
CgoEnabled: build.Default.CgoEnabled,
Compiler: build.Default.Compiler,
ReleaseTags: build.Default.ReleaseTags,
HasSubdir: test.hasSubdir,
ReadDir: test.readDir,
OpenFile: test.openFile,
IsDir: test.isDir,
}
}
const (
magicGOPATHUnix = "/wire_gopath"
magicGOPATHWindows = `C:\wire_gopath`
)
func magicGOPATH() string {
if runtime.GOOS == "windows" {
return magicGOPATHWindows
}
return magicGOPATHUnix
}
func (test *testCase) hasSubdir(root, dir string) (rel string, ok bool) {
// Don't consult filesystem, just lexical.
if dir == root {
return "", true
}
prefix := root
if !strings.HasSuffix(prefix, string(filepath.Separator)) {
prefix += string(filepath.Separator)
}
if !strings.HasPrefix(dir, prefix) {
return "", false
}
return filepath.ToSlash(dir[len(prefix):]), true
}
func (test *testCase) resolve(path string) (resolved string, pathType int) {
subpath, isMagic := test.hasSubdir(magicGOPATH(), path)
if !isMagic {
return path, systemPath
}
if subpath == "src" {
return "", gopathRoot
}
const srcPrefix = "src/"
if !strings.HasPrefix(subpath, srcPrefix) {
return subpath, gopathRoot
}
return subpath[len(srcPrefix):], gopathSrc
}
// Path types
const (
systemPath = iota
gopathRoot
gopathSrc
)
func (test *testCase) readDir(dir string) ([]os.FileInfo, error) {
rpath, pathType := test.resolve(dir)
switch {
case pathType == systemPath:
return ioutil.ReadDir(rpath)
case pathType == gopathRoot && rpath == "":
return []os.FileInfo{dirInfo{name: "src"}}, nil
case pathType == gopathSrc:
names := make([]string, 0, len(test.goFiles))
prefix := rpath + string(filepath.Separator)
for name := range test.goFiles {
if strings.HasPrefix(name, prefix) {
names = append(names, name[len(prefix):])
}
}
sort.Strings(names)
ents := make([]os.FileInfo, 0, len(names))
for _, name := range names {
if i := strings.IndexRune(name, filepath.Separator); i != -1 {
// Directory
dirName := name[:i]
if len(ents) == 0 || ents[len(ents)-1].Name() != dirName {
ents = append(ents, dirInfo{name: dirName})
}
continue
}
ents = append(ents, fileInfo{
name: name,
size: int64(len(test.goFiles[name])),
})
}
return ents, nil
default:
return nil, &os.PathError{
Op: "open",
Path: dir,
Err: os.ErrNotExist,
}
}
}
func (test *testCase) isDir(path string) bool {
rpath, pathType := test.resolve(path)
switch {
case pathType == systemPath:
info, err := os.Stat(rpath)
return err == nil && info.IsDir()
case pathType == gopathRoot && rpath == "":
return true
case pathType == gopathSrc:
prefix := rpath + string(filepath.Separator)
for name := range test.goFiles {
if strings.HasPrefix(name, prefix) {
return true
}
}
return false
default:
return false
}
}
type dirInfo struct {
name string
}
func (d dirInfo) Name() string { return d.name }
func (d dirInfo) Size() int64 { return 0 }
func (d dirInfo) Mode() os.FileMode { return os.ModeDir | os.ModePerm }
func (d dirInfo) ModTime() time.Time { return time.Unix(0, 0) }
func (d dirInfo) IsDir() bool { return true }
func (d dirInfo) Sys() interface{} { return nil }
type fileInfo struct {
name string
size int64
}
func (f fileInfo) Name() string { return f.name }
func (f fileInfo) Size() int64 { return f.size }
func (f fileInfo) Mode() os.FileMode { return os.ModeDir | 0666 }
func (f fileInfo) ModTime() time.Time { return time.Unix(0, 0) }
func (f fileInfo) IsDir() bool { return false }
func (f fileInfo) Sys() interface{} { return nil }
func (test *testCase) openFile(path string) (io.ReadCloser, error) {
rpath, pathType := test.resolve(path)
switch {
case pathType == systemPath:
return os.Open(path)
case pathType == gopathSrc:
content, ok := test.goFiles[rpath]
if !ok {
return nil, &os.PathError{
Op: "open",
Path: path,
Err: errors.New("does not exist or is not a file"),
}
}
return ioutil.NopCloser(bytes.NewReader(content)), nil
default:
return nil, &os.PathError{
Op: "open",
Path: path,
Err: errors.New("does not exist or is not a file"),
}
}
}
// materialize creates a new GOPATH at the given directory, which may or // materialize creates a new GOPATH at the given directory, which may or
// may not exist. // may not exist.
func (test *testCase) materialize(gopath string) error { func (test *testCase) materialize(gopath string) error {
for name, content := range test.goFiles { for name, content := range test.goFiles {
dst := filepath.Join(gopath, "src", name) dst := filepath.Join(gopath, "src", filepath.FromSlash(name))
if err := os.MkdirAll(filepath.Dir(dst), 0777); err != nil { if err := os.MkdirAll(filepath.Dir(dst), 0777); err != nil {
return fmt.Errorf("materialize GOPATH: %v", err) return fmt.Errorf("materialize GOPATH: %v", err)
} }
@@ -675,11 +550,11 @@ func (test *testCase) materialize(gopath string) error {
// //
// ... (Dependency files copied) // ... (Dependency files copied)
func writeGoMod(gopath string) error { func writeGoMod(gopath string) error {
importPath := "example.com" const importPath = "example.com"
depPath := "github.com/google/go-cloud" const depPath = "github.com/google/go-cloud"
depLoc := filepath.Join(gopath, "src", filepath.FromSlash(depPath)) depLoc := filepath.Join(gopath, "src", filepath.FromSlash(depPath))
example := fmt.Sprintf("module %s\n\nreplace %s => %s\n", importPath, depPath, depLoc) example := fmt.Sprintf("module %s\n\nreplace %s => %s\n", importPath, depPath, depLoc)
gomod := filepath.Join(gopath, "src", importPath, "go.mod") gomod := filepath.Join(gopath, "src", filepath.FromSlash(importPath), "go.mod")
if err := ioutil.WriteFile(gomod, []byte(example), 0666); err != nil { if err := ioutil.WriteFile(gomod, []byte(example), 0666); err != nil {
return fmt.Errorf("generate go.mod for %s: %v", gomod, err) return fmt.Errorf("generate go.mod for %s: %v", gomod, err)
} }
@@ -688,25 +563,3 @@ func writeGoMod(gopath string) error {
} }
return nil return nil
} }
// runGo runs a go command in dir.
func runGo(bctx *build.Context, dir string, args ...string) error {
exe := filepath.Join(bctx.GOROOT, "bin", "go")
c := exec.Command(exe, args...)
c.Env = append(os.Environ(), "GOROOT="+bctx.GOROOT, "GOARCH="+bctx.GOARCH, "GOOS="+bctx.GOOS, "GOPATH="+bctx.GOPATH)
c.Dir = dir
if bctx.CgoEnabled {
c.Env = append(c.Env, "CGO_ENABLED=1")
} else {
c.Env = append(c.Env, "CGO_ENABLED=0")
}
// TODO(someday): Set -compiler flag if needed.
out, err := c.CombinedOutput()
if err != nil {
if len(out) > 0 {
return fmt.Errorf("%v; output:\n%s", err, out)
}
return err
}
return nil
}