diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 83d4de7..af0b4f7 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -337,7 +337,7 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf } // load typechecks the packages that match the given patterns and -// includes source for all transitive dependencies. The patterns are +// includes source for all transitive dependencies. The patterns are // defined by the underlying build system. For the go tool, this is // described at https://golang.org/cmd/go/#hdr-Package_lists_and_patterns // @@ -860,25 +860,39 @@ func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*If // Assumes that call.Fun is wire.Bind. if len(call.Args) != 2 { - return nil, notePosition(fset.Position(call.Pos()), errors.New("call to Bind takes exactly two arguments")) + return nil, notePosition(fset.Position(call.Pos()), + errors.New("call to Bind takes exactly two arguments")) } // TODO(light): Verify that arguments are simple expressions. ifaceArgType := info.TypeOf(call.Args[0]) ifacePtr, ok := ifaceArgType.(*types.Pointer) if !ok { - return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("first argument to Bind must be a pointer to an interface type; found %s", types.TypeString(ifaceArgType, nil))) + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf("first argument to Bind must be a pointer to an interface type; found %s", types.TypeString(ifaceArgType, nil))) } iface := ifacePtr.Elem() methodSet, ok := iface.Underlying().(*types.Interface) if !ok { - return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("first argument to Bind must be a pointer to an interface type; found %s", types.TypeString(ifaceArgType, nil))) + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf("first argument to Bind must be a pointer to an interface type; found %s", types.TypeString(ifaceArgType, nil))) } + provided := info.TypeOf(call.Args[1]) + if bindShouldUsePointer(info, call) { + providedPtr, ok := provided.(*types.Pointer) + if !ok { + return nil, notePosition(fset.Position(call.Args[0].Pos()), + fmt.Errorf("second argument to Bind must be a pointer or a pointer to a pointer; found %s", types.TypeString(provided, nil))) + } + provided = providedPtr.Elem() + } if types.Identical(iface, provided) { - return nil, notePosition(fset.Position(call.Pos()), errors.New("cannot bind interface to itself")) + return nil, notePosition(fset.Position(call.Pos()), + errors.New("cannot bind interface to itself")) } if !types.Implements(provided, methodSet) { - return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf("%s does not implement %s", types.TypeString(provided, nil), types.TypeString(iface, nil))) + return nil, notePosition(fset.Position(call.Pos()), + fmt.Errorf("%s does not implement %s", types.TypeString(provided, nil), types.TypeString(iface, nil))) } return &IfaceBinding{ Pos: call.Pos(), @@ -1185,3 +1199,13 @@ func (pt ProvidedType) Field() *Field { } return pt.f } + +// bindShouldUsePointer loads the wire package the user is importing from their +// injector. The call is a wire marker function call. +func bindShouldUsePointer(info *types.Info, call *ast.CallExpr) bool { + // These type assertions should not fail, otherwise panic. + fun := call.Fun.(*ast.SelectorExpr) // wire.Bind + pkgName := fun.X.(*ast.Ident) // wire + wireName := info.ObjectOf(pkgName).(*types.PkgName) // wire package + return wireName.Imported().Scope().Lookup("bindToUsePointer") != nil +} diff --git a/internal/wire/testdata/BindInjectorArg/foo/foo.go b/internal/wire/testdata/BindInjectorArg/foo/foo.go index fe61a3f..5e39698 100644 --- a/internal/wire/testdata/BindInjectorArg/foo/foo.go +++ b/internal/wire/testdata/BindInjectorArg/foo/foo.go @@ -19,7 +19,7 @@ import ( ) func main() { - fmt.Println(inject(&Foo{"hello"}).Name) + fmt.Println(inject(Foo{"hello"}).Name) } type Fooer interface { @@ -30,7 +30,7 @@ type Foo struct { f string } -func (f *Foo) Foo() string { +func (f Foo) Foo() string { return f.f } diff --git a/internal/wire/testdata/BindInjectorArg/foo/wire.go b/internal/wire/testdata/BindInjectorArg/foo/wire.go index a8e56f2..b46279b 100644 --- a/internal/wire/testdata/BindInjectorArg/foo/wire.go +++ b/internal/wire/testdata/BindInjectorArg/foo/wire.go @@ -20,11 +20,10 @@ import ( "github.com/google/wire" ) -func inject(foo *Foo) *Bar { - // Currently fails because wire.Bind can't see injector args (#547). +func inject(foo Foo) *Bar { wire.Build( NewBar, - wire.Bind(new(Fooer), &Foo{}), + wire.Bind(new(Fooer), new(Foo)), ) return nil } diff --git a/internal/wire/testdata/BindInjectorArg/want/wire_gen.go b/internal/wire/testdata/BindInjectorArg/want/wire_gen.go index abbd9c6..a9ceae2 100644 --- a/internal/wire/testdata/BindInjectorArg/want/wire_gen.go +++ b/internal/wire/testdata/BindInjectorArg/want/wire_gen.go @@ -7,7 +7,7 @@ package main // Injectors from wire.go: -func inject(foo *Foo) *Bar { +func inject(foo Foo) *Bar { bar := NewBar(foo) return bar } diff --git a/internal/wire/testdata/BindInjectorArgPointer/foo/foo.go b/internal/wire/testdata/BindInjectorArgPointer/foo/foo.go new file mode 100644 index 0000000..d8b1edd --- /dev/null +++ b/internal/wire/testdata/BindInjectorArgPointer/foo/foo.go @@ -0,0 +1,43 @@ +// Copyright 2019 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" +) + +func main() { + fmt.Println(inject(&Foo{"hello"}).Name) +} + +type Fooer interface { + Foo() string +} + +type Foo struct { + f string +} + +func (f *Foo) Foo() string { + return f.f +} + +type Bar struct { + Name string +} + +func NewBar(fooer Fooer) *Bar { + return &Bar{Name: fooer.Foo()} +} diff --git a/internal/wire/testdata/BindInjectorArgPointer/foo/wire.go b/internal/wire/testdata/BindInjectorArgPointer/foo/wire.go new file mode 100644 index 0000000..125fee4 --- /dev/null +++ b/internal/wire/testdata/BindInjectorArgPointer/foo/wire.go @@ -0,0 +1,29 @@ +// Copyright 2019 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//+build wireinject + +package main + +import ( + "github.com/google/wire" +) + +func inject(foo *Foo) *Bar { + wire.Build( + NewBar, + wire.Bind(new(Fooer), new(*Foo)), + ) + return nil +} diff --git a/internal/wire/testdata/BindInjectorArgPointer/pkg b/internal/wire/testdata/BindInjectorArgPointer/pkg new file mode 100644 index 0000000..f7a5c8c --- /dev/null +++ b/internal/wire/testdata/BindInjectorArgPointer/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/BindInjectorArgPointer/want/program_out.txt b/internal/wire/testdata/BindInjectorArgPointer/want/program_out.txt new file mode 100644 index 0000000..ce01362 --- /dev/null +++ b/internal/wire/testdata/BindInjectorArgPointer/want/program_out.txt @@ -0,0 +1 @@ +hello diff --git a/internal/wire/testdata/BindInjectorArgPointer/want/wire_gen.go b/internal/wire/testdata/BindInjectorArgPointer/want/wire_gen.go new file mode 100644 index 0000000..abbd9c6 --- /dev/null +++ b/internal/wire/testdata/BindInjectorArgPointer/want/wire_gen.go @@ -0,0 +1,13 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate wire +//+build !wireinject + +package main + +// Injectors from wire.go: + +func inject(foo *Foo) *Bar { + bar := NewBar(foo) + return bar +} diff --git a/internal/wire/testdata/BindInterfaceWithValue/foo/wire.go b/internal/wire/testdata/BindInterfaceWithValue/foo/wire.go index 4353bf5..061c089 100644 --- a/internal/wire/testdata/BindInterfaceWithValue/foo/wire.go +++ b/internal/wire/testdata/BindInterfaceWithValue/foo/wire.go @@ -26,7 +26,7 @@ import ( func inject() io.Writer { wire.Build( wire.Value(os.Stdout), - wire.Bind(new(io.Writer), new(os.File)), + wire.Bind(new(io.Writer), new(*os.File)), ) return nil } diff --git a/internal/wire/testdata/ExampleWithMocks/foo/foo.go b/internal/wire/testdata/ExampleWithMocks/foo/foo.go index ac8f75c..a0f421f 100644 --- a/internal/wire/testdata/ExampleWithMocks/foo/foo.go +++ b/internal/wire/testdata/ExampleWithMocks/foo/foo.go @@ -75,7 +75,7 @@ var mockAppSet = wire.NewSet( // For each mocked dependency, add a provider and use wire.Bind to bind // the concrete type to the relevant interface. newMockTimer, - wire.Bind(new(timer), new(mockTimer)), + wire.Bind(new(timer), new(*mockTimer)), ) type timer interface { diff --git a/internal/wire/testdata/ImportedInterfaceBinding/bar/bar.go b/internal/wire/testdata/ImportedInterfaceBinding/bar/bar.go index 2aaaaf5..02099ef 100644 --- a/internal/wire/testdata/ImportedInterfaceBinding/bar/bar.go +++ b/internal/wire/testdata/ImportedInterfaceBinding/bar/bar.go @@ -39,4 +39,4 @@ func provideBar() *Bar { var Set = wire.NewSet( provideBar, - wire.Bind((*foo.Fooer)(nil), (*Bar)(nil))) + wire.Bind(new(foo.Fooer), new(*Bar))) diff --git a/internal/wire/testdata/InterfaceBinding/foo/foo.go b/internal/wire/testdata/InterfaceBinding/foo/foo.go index 5690c96..0fb12e1 100644 --- a/internal/wire/testdata/InterfaceBinding/foo/foo.go +++ b/internal/wire/testdata/InterfaceBinding/foo/foo.go @@ -42,4 +42,4 @@ func provideBar() *Bar { var Set = wire.NewSet( provideBar, - wire.Bind((*Fooer)(nil), (*Bar)(nil))) + wire.Bind(new(Fooer), new(*Bar))) diff --git a/internal/wire/testdata/InterfaceBindingDoesntImplement/foo/wire.go b/internal/wire/testdata/InterfaceBindingDoesntImplement/foo/wire.go index 568ff42..915ae77 100644 --- a/internal/wire/testdata/InterfaceBindingDoesntImplement/foo/wire.go +++ b/internal/wire/testdata/InterfaceBindingDoesntImplement/foo/wire.go @@ -22,6 +22,6 @@ import ( func injectFooer() Fooer { // wrong: string doesn't implement Fooer. - wire.Build(wire.Bind((*Fooer)(nil), "foo")) + wire.Build(wire.Bind(new(Fooer), new(string))) return nil } diff --git a/internal/wire/testdata/InterfaceBindingNotEnoughArgs/foo/wire.go b/internal/wire/testdata/InterfaceBindingNotEnoughArgs/foo/wire.go index 3f951dc..0c849f1 100644 --- a/internal/wire/testdata/InterfaceBindingNotEnoughArgs/foo/wire.go +++ b/internal/wire/testdata/InterfaceBindingNotEnoughArgs/foo/wire.go @@ -22,6 +22,6 @@ import ( func injectFooer() Fooer { // wrong: wire.Bind requires 2 args. - wire.Build(wire.Bind((*Fooer)(nil))) + wire.Build(wire.Bind(new(Fooer))) return nil } diff --git a/internal/wire/testdata/InterfaceBindingReuse/foo/wire.go b/internal/wire/testdata/InterfaceBindingReuse/foo/wire.go index ef85187..bdb434a 100644 --- a/internal/wire/testdata/InterfaceBindingReuse/foo/wire.go +++ b/internal/wire/testdata/InterfaceBindingReuse/foo/wire.go @@ -24,7 +24,7 @@ func injectFooBar() FooBar { wire.Build( provideBar, provideFooBar, - wire.Bind((*Fooer)(nil), (*Bar)(nil)), + wire.Bind(new(Fooer), new(*Bar)), ) return FooBar{} } diff --git a/internal/wire/testdata/MultipleBindings/foo/wire.go b/internal/wire/testdata/MultipleBindings/foo/wire.go index d6227ff..93a4570 100644 --- a/internal/wire/testdata/MultipleBindings/foo/wire.go +++ b/internal/wire/testdata/MultipleBindings/foo/wire.go @@ -49,5 +49,5 @@ func injectDuplicateValues() Foo { func injectDuplicateInterface() Bar { // fail: provideBar and wire.Bind both provide Bar. - panic(wire.Build(provideBar, wire.Bind(new(Bar), strings.NewReader("hello")))) + panic(wire.Build(provideBar, wire.Bind(new(Bar), new(*strings.Reader)))) } diff --git a/internal/wire/testdata/ProviderSetBindingMissingConcreteType/foo/foo.go b/internal/wire/testdata/ProviderSetBindingMissingConcreteType/foo/foo.go index 6d28bbb..e71d722 100644 --- a/internal/wire/testdata/ProviderSetBindingMissingConcreteType/foo/foo.go +++ b/internal/wire/testdata/ProviderSetBindingMissingConcreteType/foo/foo.go @@ -44,6 +44,6 @@ var ( // From the user guide: // Any set that includes an interface binding must also have a provider in // the same set that provides the concrete type. - setB = wire.NewSet(wire.Bind(new(fooer), new(foo))) + setB = wire.NewSet(wire.Bind(new(fooer), new(*foo))) setC = wire.NewSet(setA, setB) ) diff --git a/internal/wire/testdata/UnusedProviders/foo/wire.go b/internal/wire/testdata/UnusedProviders/foo/wire.go index 689a863..779207d 100644 --- a/internal/wire/testdata/UnusedProviders/foo/wire.go +++ b/internal/wire/testdata/UnusedProviders/foo/wire.go @@ -22,14 +22,14 @@ import ( func injectBar() Bar { wire.Build( - provideFoo, // needed as input for provideBar - provideBar, // needed for Bar - partiallyUsedSet, // 1/2 providers in the set are needed - provideUnused, // not needed -> error - wire.Value("unused"), // not needed -> error - unusedSet, // nothing in set is needed -> error - wire.Bind((*Fooer)(nil), (*Foo)(nil)), // binding to Fooer is not needed -> error - wire.FieldsOf(new(S), "Cfg"), // S.Cfg not needed -> error + provideFoo, // needed as input for provideBar + provideBar, // needed for Bar + partiallyUsedSet, // 1/2 providers in the set are needed + provideUnused, // not needed -> error + wire.Value("unused"), // not needed -> error + unusedSet, // nothing in set is needed -> error + wire.Bind(new(Fooer), new(*Foo)), // binding to Fooer is not needed -> error + wire.FieldsOf(new(S), "Cfg"), // S.Cfg not needed -> error ) return 0 } diff --git a/wire.go b/wire.go index 3352d72..e0c5121 100644 --- a/wire.go +++ b/wire.go @@ -93,9 +93,9 @@ func Build(...interface{}) string { // A Binding maps an interface to a concrete type. type Binding struct{} -// Bind declares that a concrete type should be used to satisfy a -// dependency on the type of iface, which must be a pointer to an -// interface type. +// Bind declares that a concrete type should be used to satisfy a dependency on +// the type of iface. iface must be a pointer to an interface type, to must be a +// pointer to a concrete type. // // Example: // @@ -108,12 +108,16 @@ type Binding struct{} // func (MyFoo) Foo() {} // // var MySet = wire.NewSet( -// MyFoo{}, +// wire.Struct(new(MyFoo)) // wire.Bind(new(Fooer), new(MyFoo))) func Bind(iface, to interface{}) Binding { return Binding{} } +// bindToUsePointer is detected by the wire tool to indicate that Bind's second argument should take a pointer. +// See https://github.com/google/wire/issues/120 for details. +const bindToUsePointer = true + // A ProvidedValue is an expression that is copied to the generated injector. type ProvidedValue struct{}