Support variadic provider and injector functions (#91)

Fixes #61
This commit is contained in:
shantuo
2018-12-03 08:30:42 -08:00
committed by Ross Light
parent 65d810f60a
commit ef9bb67152
8 changed files with 100 additions and 4 deletions

View File

@@ -57,6 +57,9 @@ type call struct {
// This will be nil for kind == valueExpr. // This will be nil for kind == valueExpr.
args []int args []int
// varargs is true if the provider function is variadic.
varargs bool
// fieldNames maps the arguments to struct field names. // fieldNames maps the arguments to struct field names.
// This will only be set if kind == structProvider. // This will only be set if kind == structProvider.
fieldNames []string fieldNames []string
@@ -192,6 +195,7 @@ dfs:
pkg: p.Pkg, pkg: p.Pkg,
name: p.Name, name: p.Name,
args: args, args: args,
varargs: p.Varargs,
fieldNames: p.Fields, fieldNames: p.Fields,
ins: ins, ins: ins,
out: curr.t, out: curr.t,

View File

@@ -153,6 +153,9 @@ type Provider struct {
// Args is the list of data dependencies this provider has. // Args is the list of data dependencies this provider has.
Args []ProviderInput Args []ProviderInput
// Varargs is true if the provider function is variadic.
Varargs bool
// IsStruct is true if this provider is a named struct type. // IsStruct is true if this provider is a named struct type.
// Otherwise it's a function. // Otherwise it's a function.
IsStruct bool IsStruct bool
@@ -639,6 +642,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
Name: fn.Name(), Name: fn.Name(),
Pos: fn.Pos(), Pos: fn.Pos(),
Args: make([]ProviderInput, params.Len()), Args: make([]ProviderInput, params.Len()),
Varargs: sig.Variadic(),
Out: []types.Type{providerSig.out}, Out: []types.Type{providerSig.out},
HasCleanup: providerSig.cleanup, HasCleanup: providerSig.cleanup,
HasErr: providerSig.err, HasErr: providerSig.err,

View File

@@ -0,0 +1,31 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"strings"
)
func main() {
fmt.Println(injectedMessage("", "Hello,", "World!"))
}
type title string
// provideMessage provides a friendly user greeting.
func provideMessage(words ...string) string {
return strings.Join(words, " ")
}

View File

@@ -0,0 +1,26 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//+build wireinject
package main
import (
"github.com/google/wire"
)
func injectedMessage(t title, lines ...string) string {
wire.Build(provideMessage)
return ""
}

1
internal/wire/testdata/Varargs/pkg vendored Normal file
View File

@@ -0,0 +1 @@
example.com/foo

View File

@@ -0,0 +1 @@
Hello, World!

View File

@@ -0,0 +1,13 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate wire
//+build !wireinject
package main
// Injectors from wire.go:
func injectedMessage(t title, lines ...string) string {
string2 := provideMessage(lines...)
return string2
}

View File

@@ -333,12 +333,12 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov
} }
// Perform one pass to collect all imports, followed by the real pass. // Perform one pass to collect all imports, followed by the real pass.
injectPass(name, params, injectSig, calls, &injectorGen{ injectPass(name, sig, calls, &injectorGen{
g: g, g: g,
errVar: disambiguate("err", g.nameInFileScope), errVar: disambiguate("err", g.nameInFileScope),
discard: true, discard: true,
}) })
injectPass(name, params, injectSig, calls, &injectorGen{ injectPass(name, sig, calls, &injectorGen{
g: g, g: g,
errVar: disambiguate("err", g.nameInFileScope), errVar: disambiguate("err", g.nameInFileScope),
discard: false, discard: false,
@@ -551,7 +551,14 @@ type injectorGen struct {
} }
// injectPass generates an injector given the output from analysis. // injectPass generates an injector given the output from analysis.
func injectPass(name string, params *types.Tuple, injectSig outputSignature, calls []call, ig *injectorGen) { // The sig passed in should be verified.
func injectPass(name string, sig *types.Signature, calls []call, ig *injectorGen) {
params := sig.Params()
injectSig, err := funcOutput(sig)
if err != nil {
// This should be checked by the caller already.
panic(err)
}
ig.p("func %s(", name) ig.p("func %s(", name)
for i := 0; i < params.Len(); i++ { for i := 0; i < params.Len(); i++ {
if i > 0 { if i > 0 {
@@ -565,8 +572,14 @@ func injectPass(name string, params *types.Tuple, injectSig outputSignature, cal
a = disambiguate(a, ig.nameInInjector) a = disambiguate(a, ig.nameInInjector)
} }
ig.paramNames = append(ig.paramNames, a) ig.paramNames = append(ig.paramNames, a)
if sig.Variadic() && i == params.Len()-1 {
// Keep the varargs signature instead of a slice for the last argument if the
// injector is variadic.
ig.p("%s ...%s", ig.paramNames[i], types.TypeString(pi.Type().(*types.Slice).Elem(), ig.g.qualifyPkg))
} else {
ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg)) ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg))
} }
}
outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg) outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg)
switch { switch {
case injectSig.cleanup && injectSig.err: case injectSig.cleanup && injectSig.err:
@@ -639,6 +652,9 @@ func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputS
ig.p("%s", ig.localNames[a-len(ig.paramNames)]) ig.p("%s", ig.localNames[a-len(ig.paramNames)])
} }
} }
if c.varargs {
ig.p("...")
}
ig.p(")\n") ig.p(")\n")
if c.hasErr { if c.hasErr {
ig.p("\tif %s != nil {\n", ig.errVar) ig.p("\tif %s != nil {\n", ig.errVar)