wire: change call sites to allow multiple errors (google/go-cloud#118)
Updated the call sites to allow multiple errors to be returned from the package. Load is now permitted to return partial success. Updates google/go-cloud#5
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"go/token"
|
||||
@@ -68,9 +69,10 @@ func generate(pkg string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := wire.Generate(&build.Default, wd, pkg)
|
||||
if err != nil {
|
||||
return err
|
||||
out, errs := wire.Generate(&build.Default, wd, pkg)
|
||||
if len(errs) > 0 {
|
||||
logErrors(errs)
|
||||
return errors.New("generate failed")
|
||||
}
|
||||
if len(out) == 0 {
|
||||
// No Wire directives, don't write anything.
|
||||
@@ -94,10 +96,8 @@ func show(pkgs ...string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err := wire.Load(&build.Default, wd, pkgs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, errs := wire.Load(&build.Default, wd, pkgs)
|
||||
if info != nil {
|
||||
keys := make([]wire.ProviderSetID, 0, len(info.Sets))
|
||||
for k := range info.Sets {
|
||||
keys = append(keys, k)
|
||||
@@ -144,6 +144,11 @@ func show(pkgs ...string) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
logErrors(errs)
|
||||
return errors.New("error loading packages")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -328,3 +333,9 @@ func formatProviderSetName(importPath, varName string) string {
|
||||
// Since varName is an identifier, it doesn't make sense to quote.
|
||||
return strconv.Quote(importPath) + "." + varName
|
||||
}
|
||||
|
||||
func logErrors(errs []error) {
|
||||
for _, err := range errs {
|
||||
fmt.Fprintln(os.Stderr, strings.Replace(err.Error(), "\n", "\n\t", -1))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,8 +144,9 @@ type Value struct {
|
||||
}
|
||||
|
||||
// Load finds all the provider sets in the given packages, as well as
|
||||
// the provider sets' transitive dependencies.
|
||||
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
|
||||
// the provider sets' transitive dependencies. It may return both an error
|
||||
// and Info.
|
||||
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
|
||||
// TODO(light): Stop errors from printing to stderr.
|
||||
conf := &loader.Config{
|
||||
Build: bctx,
|
||||
@@ -157,7 +158,7 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
|
||||
}
|
||||
prog, err := conf.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load: %v", err)
|
||||
return nil, []error{fmt.Errorf("load: %v", err)}
|
||||
}
|
||||
info := &Info{
|
||||
Fset: prog.Fset,
|
||||
|
||||
@@ -38,10 +38,10 @@ import (
|
||||
|
||||
// Generate performs dependency injection for a single package,
|
||||
// returning the gofmt'd Go source code.
|
||||
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) {
|
||||
mainPkg, err := bctx.Import(pkg, wd, build.FindOnly)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load: %v", err)
|
||||
return nil, []error{fmt.Errorf("load: %v", err)}
|
||||
}
|
||||
// TODO(light): Stop errors from printing to stderr.
|
||||
conf := &loader.Config{
|
||||
@@ -59,17 +59,17 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
|
||||
prog, err := conf.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load: %v", err)
|
||||
return nil, []error{fmt.Errorf("load: %v", err)}
|
||||
}
|
||||
if len(prog.InitialPackages()) != 1 {
|
||||
// This is more of a violated precondition than anything else.
|
||||
return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))
|
||||
return nil, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))}
|
||||
}
|
||||
pkgInfo := prog.InitialPackages()[0]
|
||||
g := newGen(prog, pkgInfo.Pkg.Path())
|
||||
injectorFiles, err := generateInjectors(g, pkgInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, []error{err}
|
||||
}
|
||||
copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info)
|
||||
goSrc := g.frame()
|
||||
@@ -77,7 +77,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
if err != nil {
|
||||
// This is likely a bug from a poorly generated source file.
|
||||
// Return an error and the unformatted source.
|
||||
return goSrc, err
|
||||
return goSrc, []error{err}
|
||||
}
|
||||
return fmtSrc, nil
|
||||
}
|
||||
|
||||
@@ -69,17 +69,17 @@ func TestWire(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
bctx := test.buildContext()
|
||||
gen, err := Generate(bctx, wd, test.pkg)
|
||||
gen, errs := Generate(bctx, wd, test.pkg)
|
||||
if len(gen) > 0 {
|
||||
defer t.Logf("wire_gen.go:\n%s", gen)
|
||||
}
|
||||
if err != nil {
|
||||
if len(errs) > 0 {
|
||||
if !test.wantError {
|
||||
t.Fatalf("wirego: %v", err)
|
||||
t.Fatalf("wirego: %v", errs)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil && test.wantError {
|
||||
if len(errs) == 0 && test.wantError {
|
||||
t.Fatal("wirego succeeded; want error")
|
||||
}
|
||||
|
||||
@@ -133,15 +133,15 @@ func TestWire(t *testing.T) {
|
||||
}
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
bctx := test.buildContext()
|
||||
gold, err := Generate(bctx, wd, test.pkg)
|
||||
if err != nil {
|
||||
t.Fatal("wirego:", err)
|
||||
gold, errs := Generate(bctx, wd, test.pkg)
|
||||
if len(errs) > 0 {
|
||||
t.Fatal("wirego:", errs)
|
||||
}
|
||||
goldstr := string(gold)
|
||||
for i := 0; i < runs-1; i++ {
|
||||
out, err := Generate(bctx, wd, test.pkg)
|
||||
if err != nil {
|
||||
t.Fatal("wirego (on subsequent run):", err)
|
||||
out, errs := Generate(bctx, wd, test.pkg)
|
||||
if len(errs) > 0 {
|
||||
t.Fatal("wirego (on subsequent run):", errs)
|
||||
}
|
||||
if !bytes.Equal(gold, out) {
|
||||
t.Fatalf("wirego output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out)))
|
||||
|
||||
Reference in New Issue
Block a user