Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

refactor mockgen and cleanup #536

Merged
merged 2 commits into from Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 19 additions & 19 deletions gomock/call.go
Expand Up @@ -50,16 +50,16 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle
t.Helper()

// TODO: check arity, types.
margs := make([]Matcher, len(args))
mArgs := make([]Matcher, len(args))
for i, arg := range args {
if m, ok := arg.(Matcher); ok {
margs[i] = m
mArgs[i] = m
} else if arg == nil {
// Handle nil specially so that passing a nil interface value
// will match the typed nils of concrete args.
margs[i] = Nil()
mArgs[i] = Nil()
} else {
margs[i] = Eq(arg)
mArgs[i] = Eq(arg)
}
}

Expand All @@ -76,7 +76,7 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle
return rets
}}
return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
}

// AnyTimes allows the expectation to be called 0 or more times
Expand Down Expand Up @@ -113,19 +113,19 @@ func (c *Call) DoAndReturn(f interface{}) *Call {
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
vArgs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
vArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
vArgs[i] = reflect.Zero(ft.In(i))
}
}
vrets := v.Call(vargs)
rets := make([]interface{}, len(vrets))
for i, ret := range vrets {
vRets := v.Call(vArgs)
rets := make([]interface{}, len(vRets))
for i, ret := range vRets {
rets[i] = ret.Interface()
}
return rets
Expand All @@ -142,17 +142,17 @@ func (c *Call) Do(f interface{}) *Call {
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
vArgs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
vArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
vArgs[i] = reflect.Zero(ft.In(i))
}
}
v.Call(vargs)
v.Call(vArgs)
return nil
})
return c
Expand Down Expand Up @@ -353,12 +353,12 @@ func (c *Call) matches(args []interface{}) error {
// matches all the remaining arguments or the lack of any.
// Convert the remaining arguments, if any, into a slice of the
// expected type.
vargsType := c.methodType.In(c.methodType.NumIn() - 1)
vargs := reflect.MakeSlice(vargsType, 0, len(args)-i)
vArgsType := c.methodType.In(c.methodType.NumIn() - 1)
vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i)
for _, arg := range args[i:] {
vargs = reflect.Append(vargs, reflect.ValueOf(arg))
vArgs = reflect.Append(vArgs, reflect.ValueOf(arg))
}
if m.Matches(vargs.Interface()) {
if m.Matches(vArgs.Interface()) {
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
// Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())
Expand Down
2 changes: 1 addition & 1 deletion gomock/callset_test.go
Expand Up @@ -84,7 +84,7 @@ func TestCallSetFindMatch(t *testing.T) {

c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))
cs.exhausted = map[callSetKey][]*Call{
callSetKey{receiver: receiver, fname: method}: []*Call{c1},
{receiver: receiver, fname: method}: {c1},
}

_, err := cs.FindMatch(receiver, method, args)
Expand Down
1 change: 0 additions & 1 deletion gomock/matchers.go
Expand Up @@ -153,7 +153,6 @@ func (n notMatcher) Matches(x interface{}) bool {
}

func (n notMatcher) String() string {
// TODO: Improve this if we add a NotString method to the Matcher interface.
return "not(" + n.m.String() + ")"
}

Expand Down
50 changes: 44 additions & 6 deletions mockgen/mockgen.go
Expand Up @@ -38,6 +38,7 @@ import (

"github.com/golang/mock/mockgen/model"

"golang.org/x/mod/modfile"
toolsimports "golang.org/x/tools/imports"
)

Expand Down Expand Up @@ -84,6 +85,7 @@ func main() {
log.Fatal("Expected exactly two arguments")
}
packageName = flag.Arg(0)
interfaces := strings.Split(flag.Arg(1), ",")
if packageName == "." {
dir, err := os.Getwd()
if err != nil {
Expand All @@ -94,7 +96,7 @@ func main() {
log.Fatalf("Parse package name failed: %v", err)
}
}
pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ","))
pkg, err = reflectMode(packageName, interfaces)
}
if err != nil {
log.Fatalf("Loading input failed: %v", err)
Expand Down Expand Up @@ -394,11 +396,6 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa
g.p("}")
g.p("")

// TODO: Re-enable this if we can import the interface reliably.
// g.p("// Verify that the mock satisfies the interface at compile time.")
// g.p("var _ %v = (*%v)(nil)", typeName, mockType)
// g.p("")

g.p("// New%v creates a new mock instance.", mockType)
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
g.in()
Expand Down Expand Up @@ -665,3 +662,44 @@ func printVersion() {
printModuleVersion()
}
}

// parseImportPackage get package import path via source file
// an alternative implementation is to use:
// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir}
// pkgs, err := packages.Load(cfg, "file="+source)
// However, it will call "go list" and slow down the performance
func parsePackageImport(srcDir string) (string, error) {
moduleMode := os.Getenv("GO111MODULE")
// trying to find the module
if moduleMode != "off" {
currentDir := srcDir
for {
dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod"))
if os.IsNotExist(err) {
if currentDir == filepath.Dir(currentDir) {
// at the root
break
}
currentDir = filepath.Dir(currentDir)
continue
} else if err != nil {
return "", err
}
modulePath := modfile.ModulePath(dat)
return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil
}
}
// fall back to GOPATH mode
goPaths := os.Getenv("GOPATH")
if goPaths == "" {
return "", fmt.Errorf("GOPATH is not set")
}
goPathList := strings.Split(goPaths, string(os.PathListSeparator))
for _, goPath := range goPathList {
sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator)
if strings.HasPrefix(srcDir, sourceRoot) {
return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil
}
}
return "", errOutsideGoPath
}
85 changes: 85 additions & 0 deletions mockgen/mockgen_test.go
Expand Up @@ -2,6 +2,9 @@ package main

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"regexp"
"strings"
Expand Down Expand Up @@ -364,3 +367,85 @@ func Test_createPackageMap(t *testing.T) {
})
}
}

func TestParsePackageImport_FallbackGoPath(t *testing.T) {
goPath, err := ioutil.TempDir("", "gopath")
if err != nil {
t.Error(err)
}
defer func() {
if err = os.RemoveAll(goPath); err != nil {
t.Error(err)
}
}()
srcDir := filepath.Join(goPath, "src/example.com/foo")
err = os.MkdirAll(srcDir, 0755)
if err != nil {
t.Error(err)
}
key := "GOPATH"
value := goPath
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
key = "GO111MODULE"
value = "on"
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
pkgPath, err := parsePackageImport(srcDir)
expected := "example.com/foo"
if pkgPath != expected {
t.Errorf("expect %s, got %s", expected, pkgPath)
}
}

func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) {
var goPathList []string

// first gopath
goPath, err := ioutil.TempDir("", "gopath1")
if err != nil {
t.Error(err)
}
goPathList = append(goPathList, goPath)
defer func() {
if err = os.RemoveAll(goPath); err != nil {
t.Error(err)
}
}()
srcDir := filepath.Join(goPath, "src/example.com/foo")
err = os.MkdirAll(srcDir, 0755)
if err != nil {
t.Error(err)
}

// second gopath
goPath, err = ioutil.TempDir("", "gopath2")
if err != nil {
t.Error(err)
}
goPathList = append(goPathList, goPath)
defer func() {
if err = os.RemoveAll(goPath); err != nil {
t.Error(err)
}
}()

goPaths := strings.Join(goPathList, string(os.PathListSeparator))
key := "GOPATH"
value := goPaths
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
key = "GO111MODULE"
value = "on"
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
pkgPath, err := parsePackageImport(srcDir)
expected := "example.com/foo"
if pkgPath != expected {
t.Errorf("expect %s, got %s", expected, pkgPath)
}
}
5 changes: 2 additions & 3 deletions mockgen/model/model.go
Expand Up @@ -71,7 +71,7 @@ func (intf *Interface) addImports(im map[string]bool) {
}
}

// AddMethod adds a new method, deduplicating by method name.
// AddMethod adds a new method, de-duplicating by method name.
func (intf *Interface) AddMethod(m *Method) {
for _, me := range intf.Methods {
if me.Name == m.Name {
Expand Down Expand Up @@ -260,11 +260,10 @@ func (mt *MapType) addImports(im map[string]bool) {
// NamedType is an exported type in a package.
type NamedType struct {
Package string // may be empty
Type string // TODO: should this be typed Type?
Type string
}

func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
// TODO: is this right?
if pkgOverride == nt.Package {
return nt.Type
}
Expand Down