From 65ae46b7eaa146e99673e290251ea26f28139362 Mon Sep 17 00:00:00 2001 From: shantuo Date: Thu, 28 Mar 2019 09:00:20 -0700 Subject: [PATCH] internal/wire: support specifying struct fields to inject (#147) Added wire.Struct function and deprecate old form. Updates #36 --- internal/wire/parse.go | 95 ++++++++++++++++++- .../wire/testdata/ExampleWithMocks/foo/foo.go | 14 +-- .../FieldsOfImportedStruct/main/wire.go | 2 +- .../testdata/FieldsOfValueStruct/main/wire.go | 2 +- .../MultipleSimilarPackages/main/wire.go | 2 +- internal/wire/testdata/Struct/foo/foo.go | 9 +- internal/wire/testdata/Struct/foo/wire.go | 5 + .../wire/testdata/Struct/want/program_out.txt | 1 + .../wire/testdata/Struct/want/wire_gen.go | 8 ++ .../wire/testdata/StructPointer/foo/foo.go | 2 +- .../wire/testdata/StructPointer/foo/wire.go | 2 +- wire.go | 36 +++++-- 12 files changed, 155 insertions(+), 23 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index aa3d48a..83d4de7 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -21,6 +21,7 @@ import ( "go/ast" "go/token" "go/types" + "os" "strconv" "strings" @@ -550,6 +551,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex return nil, []error{notePosition(exprPos, err)} } return v, nil + case "Struct": + s, err := processStructProvider(oc.fset, info, call) + if err != nil { + return nil, []error{notePosition(exprPos, err)} + } + return s, nil case "FieldsOf": v, err := processFieldsOf(oc.fset, info, call) if err != nil { @@ -561,7 +568,7 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex } } if tn := structArgType(info, expr); tn != nil { - p, errs := processStructProvider(oc.fset, tn) + p, errs := processStructLiteralProvider(oc.fset, tn) if len(errs) > 0 { return nil, notePositionAll(exprPos, errs) } @@ -733,9 +740,11 @@ func funcOutput(sig *types.Signature) (outputSignature, error) { } } -// processStructProvider creates a provider for a named struct type. +// processStructLiteralProvider creates a provider for a named struct type. // It produces pointer and non-pointer variants via two values in Out. -func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { +// +// This is a copy of the old processStructProvider, which is deprecated now. +func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { out := typeName.Type() st, ok := out.Underlying().(*types.Struct) if !ok { @@ -743,6 +752,10 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov } pos := typeName.Pos() + fmt.Fprintf(os.Stderr, + "Deprecated: %v, see https://godoc.org/github.com/google/wire#Struct for more information.", + notePosition(fset.Position(pos), + fmt.Errorf("using struct literal to inject %s, use wire.Struct instead", typeName.Type()))) provider := &Provider{ Pkg: typeName.Pkg(), Name: typeName.Name(), @@ -766,6 +779,82 @@ func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Prov return provider, nil } +// processStructProvider creates a provider for a named struct type. +// It produces pointer and non-pointer variants via two values in Out. +func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*Provider, error) { + // Assumes that call.Fun is wire.Struct. + + if len(call.Args) < 1 { + return nil, notePosition(fset.Position(call.Pos()), + errors.New("call to Struct must specify the struct to be injected")) + } + const firstArgReqFormat = "first argument to Struct must be a pointer to a named struct; found %s" + structType := info.TypeOf(call.Args[0]) + structPtr, ok := structType.(*types.Pointer) + if !ok { + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) + } + + st, ok := structPtr.Elem().Underlying().(*types.Struct) + if !ok { + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf(firstArgReqFormat, types.TypeString(st, nil))) + } + + stExpr := call.Args[0].(*ast.CallExpr) + typeName := qualifiedIdentObject(info, stExpr.Args[0]) // should be either an identifier or selector + provider := &Provider{ + Pkg: typeName.Pkg(), + Name: typeName.Name(), + Pos: typeName.Pos(), + IsStruct: true, + Out: []types.Type{structPtr.Elem(), structPtr}, + } + if allFields(call) { + provider.Args = make([]ProviderInput, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + provider.Args[i] = ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + } + } + } else { + provider.Args = make([]ProviderInput, len(call.Args)-1) + for i := 1; i < len(call.Args); i++ { + v, err := checkField(call.Args[i], st) + if err != nil { + return nil, notePosition(fset.Position(call.Pos()), err) + } + provider.Args[i-1] = ProviderInput{ + Type: v.Type(), + FieldName: v.Name(), + } + } + } + for i := 0; i < len(provider.Args); i++ { + for j := 0; j < i; j++ { + if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { + f := st.Field(j) + return nil, notePosition(fset.Position(f.Pos()), fmt.Errorf("provider struct has multiple fields of type %s", types.TypeString(provider.Args[j].Type, nil))) + } + } + } + return provider, nil +} + +func allFields(call *ast.CallExpr) bool { + if len(call.Args) != 2 { + return false + } + b, ok := call.Args[1].(*ast.BasicLit) + if !ok { + return false + } + return strings.EqualFold(strconv.Quote("*"), b.Value) +} + // processBind creates an interface binding from a wire.Bind call. func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) { // Assumes that call.Fun is wire.Bind. diff --git a/internal/wire/testdata/ExampleWithMocks/foo/foo.go b/internal/wire/testdata/ExampleWithMocks/foo/foo.go index 0c29252..ac8f75c 100644 --- a/internal/wire/testdata/ExampleWithMocks/foo/foo.go +++ b/internal/wire/testdata/ExampleWithMocks/foo/foo.go @@ -51,8 +51,8 @@ func main() { // appSet is a provider set for creating a real app. var appSet = wire.NewSet( - app{}, - greeter{}, + wire.Struct(new(app), "*"), + wire.Struct(new(greeter), "*"), wire.InterfaceValue(new(timer), realTime{}), ) @@ -61,17 +61,17 @@ var appSet = wire.NewSet( // arguments to the injector. // It is used for Approach A. var appSetWithoutMocks = wire.NewSet( - app{}, - greeter{}, + wire.Struct(new(app), "*"), + wire.Struct(new(greeter), "*"), ) // mockAppSet is a provider set for creating a mocked app, including the mocked // dependencies. // It is used for Approach B. var mockAppSet = wire.NewSet( - app{}, - greeter{}, - appWithMocks{}, + wire.Struct(new(app), "*"), + wire.Struct(new(greeter), "*"), + wire.Struct(new(appWithMocks), "*"), // For each mocked dependency, add a provider and use wire.Bind to bind // the concrete type to the relevant interface. newMockTimer, diff --git a/internal/wire/testdata/FieldsOfImportedStruct/main/wire.go b/internal/wire/testdata/FieldsOfImportedStruct/main/wire.go index 5df1a39..5a7d0df 100644 --- a/internal/wire/testdata/FieldsOfImportedStruct/main/wire.go +++ b/internal/wire/testdata/FieldsOfImportedStruct/main/wire.go @@ -27,7 +27,7 @@ import ( func newBazService(*baz.Config) *baz.Service { wire.Build( - baz.Service{}, + wire.Struct(new(baz.Service), "*"), wire.FieldsOf( new(*baz.Config), "Foo", diff --git a/internal/wire/testdata/FieldsOfValueStruct/main/wire.go b/internal/wire/testdata/FieldsOfValueStruct/main/wire.go index 7053288..eca1ea6 100644 --- a/internal/wire/testdata/FieldsOfValueStruct/main/wire.go +++ b/internal/wire/testdata/FieldsOfValueStruct/main/wire.go @@ -27,7 +27,7 @@ import ( func newBazService() *baz.Service { wire.Build( - baz.Service{}, + wire.Struct(new(baz.Service), "*"), wire.Value(&baz.Config{ Foo: &foo.Config{1}, Bar: &bar.Config{2}, diff --git a/internal/wire/testdata/MultipleSimilarPackages/main/wire.go b/internal/wire/testdata/MultipleSimilarPackages/main/wire.go index 88ddac5..7a94c6e 100644 --- a/internal/wire/testdata/MultipleSimilarPackages/main/wire.go +++ b/internal/wire/testdata/MultipleSimilarPackages/main/wire.go @@ -43,7 +43,7 @@ func (m *MainService) String() string { func newMainService(MainConfig) *MainService { wire.Build( - MainService{}, + wire.Struct(new(MainService), "Foo", "Bar", "baz"), wire.FieldsOf( new(MainConfig), "Foo", diff --git a/internal/wire/testdata/Struct/foo/foo.go b/internal/wire/testdata/Struct/foo/foo.go index aeeaf30..9bb1e52 100644 --- a/internal/wire/testdata/Struct/foo/foo.go +++ b/internal/wire/testdata/Struct/foo/foo.go @@ -22,7 +22,9 @@ import ( func main() { fb := injectFooBar() + pfb := injectPartFooBar() fmt.Println(fb.Foo, fb.Bar) + fmt.Println(pfb.Foo, pfb.Bar) } type Foo int @@ -42,6 +44,11 @@ func provideBar() Bar { } var Set = wire.NewSet( - FooBar{}, + wire.Struct(new(FooBar), "*"), provideFoo, provideBar) + +var PartSet = wire.NewSet( + wire.Struct(new(FooBar), "Foo"), + provideFoo, +) diff --git a/internal/wire/testdata/Struct/foo/wire.go b/internal/wire/testdata/Struct/foo/wire.go index 9083b06..38ecb0c 100644 --- a/internal/wire/testdata/Struct/foo/wire.go +++ b/internal/wire/testdata/Struct/foo/wire.go @@ -24,3 +24,8 @@ func injectFooBar() FooBar { wire.Build(Set) return FooBar{} } + +func injectPartFooBar() FooBar { + wire.Build(PartSet) + return FooBar{} +} diff --git a/internal/wire/testdata/Struct/want/program_out.txt b/internal/wire/testdata/Struct/want/program_out.txt index b1ae43f..3ea9a3a 100644 --- a/internal/wire/testdata/Struct/want/program_out.txt +++ b/internal/wire/testdata/Struct/want/program_out.txt @@ -1 +1,2 @@ 41 1 +41 0 diff --git a/internal/wire/testdata/Struct/want/wire_gen.go b/internal/wire/testdata/Struct/want/wire_gen.go index 3ba452e..4fc97e9 100644 --- a/internal/wire/testdata/Struct/want/wire_gen.go +++ b/internal/wire/testdata/Struct/want/wire_gen.go @@ -16,3 +16,11 @@ func injectFooBar() FooBar { } return fooBar } + +func injectPartFooBar() FooBar { + foo := provideFoo() + fooBar := FooBar{ + Foo: foo, + } + return fooBar +} diff --git a/internal/wire/testdata/StructPointer/foo/foo.go b/internal/wire/testdata/StructPointer/foo/foo.go index e1b4148..8809f77 100644 --- a/internal/wire/testdata/StructPointer/foo/foo.go +++ b/internal/wire/testdata/StructPointer/foo/foo.go @@ -45,6 +45,6 @@ func provideBar() Bar { } var Set = wire.NewSet( - FooBar{}, + wire.Struct(new(FooBar), "*"), provideFoo, provideBar) diff --git a/internal/wire/testdata/StructPointer/foo/wire.go b/internal/wire/testdata/StructPointer/foo/wire.go index 2412601..30145c1 100644 --- a/internal/wire/testdata/StructPointer/foo/wire.go +++ b/internal/wire/testdata/StructPointer/foo/wire.go @@ -26,6 +26,6 @@ func injectFooBar() *FooBar { } func injectEmptyStruct() *Empty { - wire.Build(Empty{}) + wire.Build(wire.Struct(new(Empty))) return nil } diff --git a/wire.go b/wire.go index e7acfa9..3352d72 100644 --- a/wire.go +++ b/wire.go @@ -29,9 +29,9 @@ package wire type ProviderSet struct{} // NewSet creates a new provider set that includes the providers in its -// arguments. Each argument is a function value, a struct (zero) value, a -// provider set, a call to Bind, a call to Value, a call to InterfaceValue or a -// call to FieldsOf. +// arguments. Each argument is a function value, a provider set, a call to +// Struct, a call to Bind, a call to Value, a call to InterfaceValue or a call +// to FieldsOf. // // Passing a function value to NewSet declares that the function's first // return value type will be provided by calling the function. The arguments @@ -44,15 +44,17 @@ type ProviderSet struct{} // will call all the appropriate cleanup functions and return the error from // the injector function. // -// Passing a struct value of type S to NewSet declares that both S and *S will -// be provided by creating a new value of the appropriate type by filling in -// each field of S using the provider of the field's type. -// // Passing a ProviderSet to NewSet is the same as if the set's contents // were passed as arguments to NewSet directly. // // The behavior of passing the result of a call to other functions in this // package are described in their respective doc comments. +// +// For compatibility with older versions of Wire, passing a struct value of type +// S to NewSet declares that both S and *S will be provided by creating a new +// value of the appropriate type by filling in each field of S using the +// provider of the field's type. This form is deprecated and will be removed in +// a future version of Wire: new providers sets should use wire.Struct. func NewSet(...interface{}) ProviderSet { return ProviderSet{} } @@ -137,6 +139,26 @@ func InterfaceValue(typ interface{}, x interface{}) ProvidedValue { return ProvidedValue{} } +// A StructProvider represents a named struct. +type StructProvider struct{} + +// Struct specifies that the given struct type will be provided by filling in the fields +// in the struct that have the names given. Each of the arguments must be a name +// to the field they wish to reference. As a special case, if a single name "*" +// is given, then all of the fields in the struct will be filled in. +// +// For example: +// +// type S struct { +// MyFoo *Foo +// MyBar *Bar +// } +// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo +// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields +func Struct(structType interface{}, fieldNames ...string) StructProvider { + return StructProvider{} +} + // StructFields is a collection of the fields from a struct. type StructFields struct{}