2018-04-02 09:21:52 -07:00
package goose
import (
"fmt"
"go/ast"
2018-04-04 10:58:07 -07:00
"go/build"
2018-04-02 09:21:52 -07:00
"go/token"
"go/types"
2018-04-04 10:58:07 -07:00
"path/filepath"
2018-04-02 09:21:52 -07:00
"strconv"
"strings"
2018-04-02 10:57:48 -07:00
"unicode"
2018-04-02 09:21:52 -07:00
"golang.org/x/tools/go/loader"
)
// providerSet is a named collection of providers plus references to the
// other provider sets it pulls in. Its zero value is an empty set that is
// ready to use.
type providerSet struct {
	// providers holds the provider functions declared directly in this set.
	providers []*providerInfo
	// imports holds the references to other provider sets added to this set.
	imports []providerSetImport
}
// providerSetImport is a provider set reference paired with a source
// position.
type providerSetImport struct {
	providerSetRef

	// pos is the source position recorded alongside the reference.
	pos token.Pos
}
// providerInfo records the signature of a provider function.
type providerInfo struct {
importPath string
funcName string
pos token . Pos
2018-03-30 21:34:08 -07:00
args [ ] providerInput
2018-04-02 09:21:52 -07:00
out types . Type
hasErr bool
}
2018-03-30 21:34:08 -07:00
// providerInput describes one input of a provider.
type providerInput struct {
	// typ is the type of the input.
	typ types.Type
	// optional marks the input as optional.
	optional bool
}
// findContext bundles the per-package state threaded through directive
// processing.
type findContext struct {
	// fset is the file set used to turn positions into file:line text.
	fset *token.FileSet
	// pkg is the package whose directives are being processed.
	pkg *types.Package
	// typeInfo is the type information for pkg; its Scopes and object
	// lookups are consulted while processing.
	typeInfo *types.Info
	// r resolves import paths written in provider set references.
	r *importResolver
}
2018-04-02 09:21:52 -07:00
// findProviderSets processes a package and extracts the provider sets declared in it.
2018-03-30 21:34:08 -07:00
func findProviderSets ( fctx findContext , files [ ] * ast . File ) ( map [ string ] * providerSet , error ) {
2018-04-02 09:21:52 -07:00
sets := make ( map [ string ] * providerSet )
for _ , f := range files {
2018-03-30 21:34:08 -07:00
fileScope := fctx . typeInfo . Scopes [ f ]
if fileScope == nil {
return nil , fmt . Errorf ( "%s: no scope found for file (likely a bug)" , fctx . fset . File ( f . Pos ( ) ) . Name ( ) )
}
for _ , dg := range parseFile ( fctx . fset , f ) {
if dg . decl != nil {
if err := processDeclDirectives ( fctx , sets , fileScope , dg ) ; err != nil {
return nil , err
}
} else {
for _ , d := range dg . dirs {
if err := processUnassociatedDirective ( fctx , sets , fileScope , d ) ; err != nil {
return nil , err
2018-04-02 09:21:52 -07:00
}
}
}
}
2018-03-30 21:34:08 -07:00
}
return sets , nil
}
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
func processUnassociatedDirective ( fctx findContext , sets map [ string ] * providerSet , scope * types . Scope , d directive ) error {
switch d . kind {
case "provide" , "optional" :
return fmt . Errorf ( "%v: only functions can be marked as providers" , fctx . fset . Position ( d . pos ) )
case "use" :
// Ignore, picked up by injector flow.
case "import" :
2018-04-02 10:57:48 -07:00
args := d . args ( )
if len ( args ) < 2 {
2018-03-30 21:34:08 -07:00
return fmt . Errorf ( "%s: invalid import: expected TARGET SETREF" , fctx . fset . Position ( d . pos ) )
}
2018-04-02 10:57:48 -07:00
name := args [ 0 ]
for _ , spec := range args [ 1 : ] {
ref , err := parseProviderSetRef ( fctx . r , spec , scope , fctx . pkg . Path ( ) , d . pos )
if err != nil {
return fmt . Errorf ( "%v: %v" , fctx . fset . Position ( d . pos ) , err )
2018-04-02 09:21:52 -07:00
}
2018-04-02 10:57:48 -07:00
if ref . importPath != fctx . pkg . Path ( ) {
imported := false
for _ , imp := range fctx . pkg . Imports ( ) {
if ref . importPath == imp . Path ( ) {
imported = true
break
}
}
if ! imported {
return fmt . Errorf ( "%v: provider set %s imports %q which is not in the package's imports" , fctx . fset . Position ( d . pos ) , name , ref . importPath )
2018-04-02 09:21:52 -07:00
}
}
2018-04-02 10:57:48 -07:00
if mod := sets [ name ] ; mod != nil {
found := false
for _ , other := range mod . imports {
if ref == other . providerSetRef {
found = true
break
}
}
if ! found {
mod . imports = append ( mod . imports , providerSetImport { providerSetRef : ref , pos : d . pos } )
}
} else {
sets [ name ] = & providerSet {
imports : [ ] providerSetImport { { providerSetRef : ref , pos : d . pos } } ,
}
2018-04-02 09:21:52 -07:00
}
2018-03-30 21:34:08 -07:00
}
default :
return fmt . Errorf ( "%v: unknown directive %s" , fctx . fset . Position ( d . pos ) , d . kind )
}
return nil
}
// processDeclDirectives processes the directives associated with a top-level declaration.
func processDeclDirectives ( fctx findContext , sets map [ string ] * providerSet , scope * types . Scope , dg directiveGroup ) error {
p , err := dg . single ( fctx . fset , "provide" )
if err != nil {
return err
}
if ! p . isValid ( ) {
for _ , d := range dg . dirs {
if d . kind == "optional" {
return fmt . Errorf ( "%v: cannot use goose:%s directive on non-provider" , fctx . fset . Position ( d . pos ) , d . kind )
}
}
return nil
}
fn , ok := dg . decl . ( * ast . FuncDecl )
if ! ok {
return fmt . Errorf ( "%v: only functions can be marked as providers" , fctx . fset . Position ( p . pos ) )
}
sig := fctx . typeInfo . ObjectOf ( fn . Name ) . Type ( ) . ( * types . Signature )
optionals := make ( [ ] bool , sig . Params ( ) . Len ( ) )
for _ , d := range dg . dirs {
if d . kind == "optional" {
// Marking the given argument names as optional inputs.
2018-04-02 10:57:48 -07:00
for _ , arg := range d . args ( ) {
2018-03-30 21:34:08 -07:00
pi := paramIndex ( sig . Params ( ) , arg )
if pi == - 1 {
return fmt . Errorf ( "%v: %s is not a parameter of func %s" , fctx . fset . Position ( d . pos ) , arg , fn . Name . Name )
2018-04-02 09:21:52 -07:00
}
2018-03-30 21:34:08 -07:00
optionals [ pi ] = true
2018-04-02 09:21:52 -07:00
}
}
}
2018-03-30 21:34:08 -07:00
fpos := fn . Pos ( )
r := sig . Results ( )
var hasErr bool
switch r . Len ( ) {
case 1 :
hasErr = false
case 2 :
if t := r . At ( 1 ) . Type ( ) ; ! types . Identical ( t , errorType ) {
return fmt . Errorf ( "%v: wrong signature for provider %s: second return type must be error" , fctx . fset . Position ( fpos ) , fn . Name . Name )
}
hasErr = true
default :
return fmt . Errorf ( "%v: wrong signature for provider %s: must have one return value and optional error" , fctx . fset . Position ( fpos ) , fn . Name . Name )
}
out := r . At ( 0 ) . Type ( )
params := sig . Params ( )
provider := & providerInfo {
importPath : fctx . pkg . Path ( ) ,
funcName : fn . Name . Name ,
pos : fn . Pos ( ) ,
args : make ( [ ] providerInput , params . Len ( ) ) ,
out : out ,
hasErr : hasErr ,
}
for i := 0 ; i < params . Len ( ) ; i ++ {
provider . args [ i ] = providerInput {
typ : params . At ( i ) . Type ( ) ,
optional : optionals [ i ] ,
}
for j := 0 ; j < i ; j ++ {
if types . Identical ( provider . args [ i ] . typ , provider . args [ j ] . typ ) {
return fmt . Errorf ( "%v: provider has multiple parameters of type %s" , fctx . fset . Position ( fpos ) , types . TypeString ( provider . args [ j ] . typ , nil ) )
}
}
}
providerSetName := fn . Name . Name
2018-04-02 10:57:48 -07:00
if args := p . args ( ) ; len ( args ) == 1 {
2018-03-30 21:34:08 -07:00
// TODO(light): validate identifier
2018-04-02 10:57:48 -07:00
providerSetName = args [ 0 ]
} else if len ( args ) > 1 {
return fmt . Errorf ( "%v: goose:provide takes at most one argument" , fctx . fset . Position ( fpos ) )
2018-03-30 21:34:08 -07:00
}
if mod := sets [ providerSetName ] ; mod != nil {
for _ , other := range mod . providers {
if types . Identical ( other . out , provider . out ) {
return fmt . Errorf ( "%v: provider set %s has multiple providers for %s (previous declaration at %v)" , fctx . fset . Position ( fn . Pos ( ) ) , providerSetName , types . TypeString ( provider . out , nil ) , fctx . fset . Position ( other . pos ) )
}
}
mod . providers = append ( mod . providers , provider )
} else {
sets [ providerSetName ] = & providerSet {
providers : [ ] * providerInfo { provider } ,
}
}
return nil
2018-04-02 09:21:52 -07:00
}
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
sets map [ string ] map [ string ] * providerSet
fset * token . FileSet
prog * loader . Program
2018-04-04 10:58:07 -07:00
r * importResolver
2018-04-02 09:21:52 -07:00
}
2018-04-04 10:58:07 -07:00
func newProviderSetCache ( prog * loader . Program , r * importResolver ) * providerSetCache {
2018-04-02 09:21:52 -07:00
return & providerSetCache {
fset : prog . Fset ,
prog : prog ,
2018-04-04 10:58:07 -07:00
r : r ,
2018-04-02 09:21:52 -07:00
}
}
func ( mc * providerSetCache ) get ( ref providerSetRef ) ( * providerSet , error ) {
if mods , cached := mc . sets [ ref . importPath ] ; cached {
mod := mods [ ref . name ]
if mod == nil {
return nil , fmt . Errorf ( "no such provider set %s in package %q" , ref . name , ref . importPath )
}
return mod , nil
}
if mc . sets == nil {
mc . sets = make ( map [ string ] map [ string ] * providerSet )
}
pkg := mc . prog . Package ( ref . importPath )
2018-03-30 21:34:08 -07:00
mods , err := findProviderSets ( findContext {
fset : mc . fset ,
pkg : pkg . Pkg ,
typeInfo : & pkg . Info ,
r : mc . r ,
} , pkg . Files )
2018-04-02 09:21:52 -07:00
if err != nil {
mc . sets [ ref . importPath ] = nil
return nil , err
}
mc . sets [ ref . importPath ] = mods
mod := mods [ ref . name ]
if mod == nil {
return nil , fmt . Errorf ( "no such provider set %s in package %q" , ref . name , ref . importPath )
}
return mod , nil
}
// providerSetRef identifies a provider set by the package that declares it
// and its name within that package. Values are comparable with ==.
type providerSetRef struct {
	// importPath is the import path of the declaring package.
	importPath string
	// name is the provider set's name inside that package.
	name string
}
2018-04-04 10:58:07 -07:00
func parseProviderSetRef ( r * importResolver , ref string , s * types . Scope , pkg string , pos token . Pos ) ( providerSetRef , error ) {
2018-04-02 09:21:52 -07:00
// TODO(light): verify that provider set name is an identifier before returning
i := strings . LastIndexByte ( ref , '.' )
if i == - 1 {
return providerSetRef { importPath : pkg , name : ref } , nil
}
imp , name := ref [ : i ] , ref [ i + 1 : ]
if strings . HasPrefix ( imp , ` " ` ) {
path , err := strconv . Unquote ( imp )
if err != nil {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: bad import path" , ref )
}
2018-04-04 10:58:07 -07:00
path , err = r . resolve ( pos , path )
if err != nil {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: %v" , ref , err )
}
2018-04-02 09:21:52 -07:00
return providerSetRef { importPath : path , name : name } , nil
}
_ , obj := s . LookupParent ( imp , pos )
if obj == nil {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: unknown identifier %s" , ref , imp )
}
pn , ok := obj . ( * types . PkgName )
if ! ok {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: %s does not name a package" , ref , imp )
}
return providerSetRef { importPath : pn . Imported ( ) . Path ( ) , name : name } , nil
}
// String renders the reference as the quoted import path, a dot, and the
// set name.
func (ref providerSetRef) String() string {
	quoted := strconv.Quote(ref.importPath)
	return quoted + "." + ref.name
}
2018-04-04 10:58:07 -07:00
// importResolver turns an import path written in a source file into the
// import path reported by the package finder.
type importResolver struct {
	// fset maps positions back to the file they came from.
	fset *token.FileSet
	// bctx is the build context passed to findPackage.
	bctx *build.Context
	// findPackage locates a package given an import path and the directory
	// of the importing file.
	findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
}
// newImportResolver builds a resolver from a loader configuration. When
// the configuration leaves the build context or the package finder unset,
// build.Default and (*build.Context).Import are used instead.
func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver {
	bctx := c.Build
	if bctx == nil {
		bctx = &build.Default
	}
	find := c.FindPackage
	if find == nil {
		find = (*build.Context).Import
	}
	return &importResolver{
		fset:        fset,
		bctx:        bctx,
		findPackage: find,
	}
}
// resolve looks up path as it would be imported from the directory of the
// file containing pos, and returns the import path of the package found.
// Only the package's location is looked up (build.FindOnly).
func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
	filename := r.fset.File(pos).Name()
	found, err := r.findPackage(r.bctx, path, filepath.Dir(filename), build.FindOnly)
	if err != nil {
		return "", err
	}
	return found.ImportPath, nil
}
2018-03-30 21:34:08 -07:00
// A directive is a parsed goose comment.
2018-04-02 09:21:52 -07:00
type directive struct {
	// pos is the source position recorded for the directive.
	pos token.Pos
	// kind is the directive name, i.e. the text following "goose:".
	kind string
	// line is the remainder of the directive after its kind.
	line string
}
2018-03-30 21:34:08 -07:00
// directiveGroup collects the directives that belong to one declaration.
// A nil decl marks the group of directives attached to no declaration.
type directiveGroup struct {
	// decl is the declaration the directives are attached to, if any.
	decl ast.Decl
	// dirs lists the directives in the group.
	dirs []directive
}
// parseFile extracts the directives from a file, grouped by declaration.
//
// The result holds one group per top-level declaration that carries at
// least one directive, in declaration order. If any directive is not tied
// to a declaration, a group with a nil decl comes first. Only "provide",
// "optional" and "use" directives stay with a declaration; any other kind
// found in a declaration's comments is moved to the unassociated group.
// Directives from comments not reached through any declaration are
// appended to that group in source order.
func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
	cmap := ast.NewCommentMap(fset, f, f.Comments)
	// Reserve first group for directives that don't associate with a
	// declaration, like import.
	groups := make([]directiveGroup, 1, len(f.Decls)+1)
	// Walk declarations and add to groups.
	for _, decl := range f.Decls {
		grp := directiveGroup{decl: decl}
		ast.Inspect(decl, func(node ast.Node) bool {
			if g := cmap[node]; len(g) > 0 {
				for _, cg := range g {
					start := len(grp.dirs)
					grp.dirs = extractDirectives(grp.dirs, cg)
					// Move directives that don't associate into the unassociated group.
					// The kept ones are compacted in place from index start.
					n := 0
					for i := start; i < len(grp.dirs); i++ {
						if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" {
							grp.dirs[start+n] = grp.dirs[i]
							n++
						} else {
							groups[0].dirs = append(groups[0].dirs, grp.dirs[i])
						}
					}
					grp.dirs = grp.dirs[:start+n]
				}
				// Consumed comments must not be picked up again below.
				delete(cmap, node)
			}
			return true
		})
		if len(grp.dirs) > 0 {
			groups = append(groups, grp)
		}
	}
	// Place remaining directives into the unassociated group. Comments()
	// yields the leftover groups sorted in source order; ranging over the
	// map directly would make the order of these directives random.
	unassoc := &groups[0]
	for _, cg := range cmap.Comments() {
		unassoc.dirs = extractDirectives(unassoc.dirs, cg)
	}
	if len(unassoc.dirs) == 0 {
		return groups[1:]
	}
	return groups
}
2018-04-02 09:21:52 -07:00
func extractDirectives ( d [ ] directive , cg * ast . CommentGroup ) [ ] directive {
const prefix = "goose:"
text := cg . Text ( )
for len ( text ) > 0 {
text = strings . TrimLeft ( text , " \t\r\n" )
if ! strings . HasPrefix ( text , prefix ) {
break
}
line := text [ len ( prefix ) : ]
2018-04-02 10:57:48 -07:00
// Text() is always newline terminated.
i := strings . IndexByte ( line , '\n' )
line , text = line [ : i ] , line [ i + 1 : ]
2018-04-02 09:21:52 -07:00
if i := strings . IndexByte ( line , ' ' ) ; i != - 1 {
d = append ( d , directive {
kind : line [ : i ] ,
line : strings . TrimSpace ( line [ i + 1 : ] ) ,
pos : cg . Pos ( ) , // TODO(light): more precise position
} )
} else {
d = append ( d , directive {
kind : line ,
pos : cg . Pos ( ) , // TODO(light): more precise position
} )
}
}
return d
}
2018-03-30 21:34:08 -07:00
// single returns the one directive of the given kind in the group. If the
// group has none, the zero directive is returned with a nil error. A second
// directive of that kind is an error; the message names the declaration
// when it is a function or a type declaration with a single spec.
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
	var (
		match directive
		seen  bool
	)
	for _, d := range dg.dirs {
		if d.kind != kind {
			continue
		}
		if !seen {
			match, seen = d, true
			continue
		}

		// A second directive of this kind: report it at its own position.
		where := fset.Position(d.pos)
		if fn, isFunc := dg.decl.(*ast.FuncDecl); isFunc {
			return directive{}, fmt.Errorf("%v: multiple %s directives for %s", where, kind, fn.Name.Name)
		}
		if gen, isGen := dg.decl.(*ast.GenDecl); isGen && gen.Tok == token.TYPE && len(gen.Specs) == 1 {
			typeName := gen.Specs[0].(*ast.TypeSpec).Name.Name
			return directive{}, fmt.Errorf("%v: multiple %s directives for %s", where, kind, typeName)
		}
		return directive{}, fmt.Errorf("%v: multiple %s directives", where, kind)
	}
	return match, nil
}
// isValid reports whether d has a kind, i.e. is not the zero directive.
func (d directive) isValid() bool {
	return len(d.kind) != 0
}
2018-04-02 10:57:48 -07:00
// args splits the directive line into whitespace-separated tokens.
//
// A double-quoted section may contain whitespace, and inside it a backslash
// makes the following character literal, so an escaped quote does not end
// the section. Quotes and backslashes are kept in the returned tokens. A
// quoted section that is never closed runs to the end of the line.
func (d directive) args() []string {
	const (
		atBoundary     = iota // between tokens
		inToken               // inside an unquoted part of a token
		inQuote               // inside a double-quoted part of a token
		afterBackslash        // just after a backslash inside quotes
	)

	var tokens []string
	begin := -1 // byte offset where the current token started, or -1
	mode := atBoundary
	for pos, ch := range d.line {
		switch mode {
		case atBoundary:
			if ch == '"' {
				begin, mode = pos, inQuote
			} else if !unicode.IsSpace(ch) {
				begin, mode = pos, inToken
			}
		case inToken:
			if unicode.IsSpace(ch) {
				tokens = append(tokens, d.line[begin:pos])
				begin, mode = -1, atBoundary
			} else if ch == '"' {
				mode = inQuote
			}
		case inQuote:
			if ch == '"' {
				mode = inToken
			} else if ch == '\\' {
				mode = afterBackslash
			}
		case afterBackslash:
			// The escaped character is consumed; continue inside the quotes.
			mode = inQuote
		default:
			panic("unreachable")
		}
	}
	if begin != -1 {
		tokens = append(tokens, d.line[begin:])
	}
	return tokens
}
2018-04-02 09:21:52 -07:00
// isInjectFile reports whether f is an injection template: a file with a
// comment group whose text starts with "+build" and mentions "gooseinject".
func isInjectFile(f *ast.File) bool {
	// TODO(light): better determination
	for i := range f.Comments {
		body := f.Comments[i].Text()
		if !strings.HasPrefix(body, "+build") {
			continue
		}
		if strings.Contains(body, "gooseinject") {
			return true
		}
	}
	return false
}
2018-03-30 21:34:08 -07:00
// paramIndex looks for a parameter called name in params and returns the
// index of the first match, or -1 when there is none.
func paramIndex(params *types.Tuple, name string) int {
	for i, n := 0, params.Len(); i < n; i++ {
		if params.At(i).Name() != name {
			continue
		}
		return i
	}
	return -1
}