Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

Commit

Permalink
refactor mockgen and cleanup (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
codyoss committed Feb 26, 2021
1 parent 58935d8 commit 7105dde
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 298 deletions.
38 changes: 19 additions & 19 deletions gomock/call.go
Expand Up @@ -50,16 +50,16 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle
t.Helper()

// TODO: check arity, types.
margs := make([]Matcher, len(args))
mArgs := make([]Matcher, len(args))
for i, arg := range args {
if m, ok := arg.(Matcher); ok {
margs[i] = m
mArgs[i] = m
} else if arg == nil {
// Handle nil specially so that passing a nil interface value
// will match the typed nils of concrete args.
margs[i] = Nil()
mArgs[i] = Nil()
} else {
margs[i] = Eq(arg)
mArgs[i] = Eq(arg)
}
}

Expand All @@ -76,7 +76,7 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle
return rets
}}
return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
}

// AnyTimes allows the expectation to be called 0 or more times
Expand Down Expand Up @@ -113,19 +113,19 @@ func (c *Call) DoAndReturn(f interface{}) *Call {
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
vArgs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
vArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
vArgs[i] = reflect.Zero(ft.In(i))
}
}
vrets := v.Call(vargs)
rets := make([]interface{}, len(vrets))
for i, ret := range vrets {
vRets := v.Call(vArgs)
rets := make([]interface{}, len(vRets))
for i, ret := range vRets {
rets[i] = ret.Interface()
}
return rets
Expand All @@ -142,17 +142,17 @@ func (c *Call) Do(f interface{}) *Call {
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
vargs := make([]reflect.Value, len(args))
vArgs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
if args[i] != nil {
vargs[i] = reflect.ValueOf(args[i])
vArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vargs[i] = reflect.Zero(ft.In(i))
vArgs[i] = reflect.Zero(ft.In(i))
}
}
v.Call(vargs)
v.Call(vArgs)
return nil
})
return c
Expand Down Expand Up @@ -353,12 +353,12 @@ func (c *Call) matches(args []interface{}) error {
// matches all the remaining arguments or the lack of any.
// Convert the remaining arguments, if any, into a slice of the
// expected type.
vargsType := c.methodType.In(c.methodType.NumIn() - 1)
vargs := reflect.MakeSlice(vargsType, 0, len(args)-i)
vArgsType := c.methodType.In(c.methodType.NumIn() - 1)
vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i)
for _, arg := range args[i:] {
vargs = reflect.Append(vargs, reflect.ValueOf(arg))
vArgs = reflect.Append(vArgs, reflect.ValueOf(arg))
}
if m.Matches(vargs.Interface()) {
if m.Matches(vArgs.Interface()) {
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
// Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())
Expand Down
2 changes: 1 addition & 1 deletion gomock/callset_test.go
Expand Up @@ -84,7 +84,7 @@ func TestCallSetFindMatch(t *testing.T) {

c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))
cs.exhausted = map[callSetKey][]*Call{
callSetKey{receiver: receiver, fname: method}: []*Call{c1},
{receiver: receiver, fname: method}: {c1},
}

_, err := cs.FindMatch(receiver, method, args)
Expand Down
1 change: 0 additions & 1 deletion gomock/matchers.go
Expand Up @@ -153,7 +153,6 @@ func (n notMatcher) Matches(x interface{}) bool {
}

func (n notMatcher) String() string {
// TODO: Improve this if we add a NotString method to the Matcher interface.
return "not(" + n.m.String() + ")"
}

Expand Down
50 changes: 44 additions & 6 deletions mockgen/mockgen.go
Expand Up @@ -38,6 +38,7 @@ import (

"github.com/golang/mock/mockgen/model"

"golang.org/x/mod/modfile"
toolsimports "golang.org/x/tools/imports"
)

Expand Down Expand Up @@ -84,6 +85,7 @@ func main() {
log.Fatal("Expected exactly two arguments")
}
packageName = flag.Arg(0)
interfaces := strings.Split(flag.Arg(1), ",")
if packageName == "." {
dir, err := os.Getwd()
if err != nil {
Expand All @@ -94,7 +96,7 @@ func main() {
log.Fatalf("Parse package name failed: %v", err)
}
}
pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ","))
pkg, err = reflectMode(packageName, interfaces)
}
if err != nil {
log.Fatalf("Loading input failed: %v", err)
Expand Down Expand Up @@ -394,11 +396,6 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa
g.p("}")
g.p("")

// TODO: Re-enable this if we can import the interface reliably.
// g.p("// Verify that the mock satisfies the interface at compile time.")
// g.p("var _ %v = (*%v)(nil)", typeName, mockType)
// g.p("")

g.p("// New%v creates a new mock instance.", mockType)
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
g.in()
Expand Down Expand Up @@ -665,3 +662,44 @@ func printVersion() {
printModuleVersion()
}
}

// parseImportPackage get package import path via source file
// an alternative implementation is to use:
// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir}
// pkgs, err := packages.Load(cfg, "file="+source)
// However, it will call "go list" and slow down the performance
func parsePackageImport(srcDir string) (string, error) {
moduleMode := os.Getenv("GO111MODULE")
// trying to find the module
if moduleMode != "off" {
currentDir := srcDir
for {
dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod"))
if os.IsNotExist(err) {
if currentDir == filepath.Dir(currentDir) {
// at the root
break
}
currentDir = filepath.Dir(currentDir)
continue
} else if err != nil {
return "", err
}
modulePath := modfile.ModulePath(dat)
return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil
}
}
// fall back to GOPATH mode
goPaths := os.Getenv("GOPATH")
if goPaths == "" {
return "", fmt.Errorf("GOPATH is not set")
}
goPathList := strings.Split(goPaths, string(os.PathListSeparator))
for _, goPath := range goPathList {
sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator)
if strings.HasPrefix(srcDir, sourceRoot) {
return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil
}
}
return "", errOutsideGoPath
}
85 changes: 85 additions & 0 deletions mockgen/mockgen_test.go
Expand Up @@ -2,6 +2,9 @@ package main

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"regexp"
"strings"
Expand Down Expand Up @@ -364,3 +367,85 @@ func Test_createPackageMap(t *testing.T) {
})
}
}

func TestParsePackageImport_FallbackGoPath(t *testing.T) {
goPath, err := ioutil.TempDir("", "gopath")
if err != nil {
t.Error(err)
}
defer func() {
if err = os.RemoveAll(goPath); err != nil {
t.Error(err)
}
}()
srcDir := filepath.Join(goPath, "src/example.com/foo")
err = os.MkdirAll(srcDir, 0755)
if err != nil {
t.Error(err)
}
key := "GOPATH"
value := goPath
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
key = "GO111MODULE"
value = "on"
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
pkgPath, err := parsePackageImport(srcDir)
expected := "example.com/foo"
if pkgPath != expected {
t.Errorf("expect %s, got %s", expected, pkgPath)
}
}

func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) {
var goPathList []string

// first gopath
goPath, err := ioutil.TempDir("", "gopath1")
if err != nil {
t.Error(err)
}
goPathList = append(goPathList, goPath)
defer func() {
if err = os.RemoveAll(goPath); err != nil {
t.Error(err)
}
}()
srcDir := filepath.Join(goPath, "src/example.com/foo")
err = os.MkdirAll(srcDir, 0755)
if err != nil {
t.Error(err)
}

// second gopath
goPath, err = ioutil.TempDir("", "gopath2")
if err != nil {
t.Error(err)
}
goPathList = append(goPathList, goPath)
defer func() {
if err = os.RemoveAll(goPath); err != nil {
t.Error(err)
}
}()

goPaths := strings.Join(goPathList, string(os.PathListSeparator))
key := "GOPATH"
value := goPaths
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
key = "GO111MODULE"
value = "on"
if err := os.Setenv(key, value); err != nil {
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
}
pkgPath, err := parsePackageImport(srcDir)
expected := "example.com/foo"
if pkgPath != expected {
t.Errorf("expect %s, got %s", expected, pkgPath)
}
}
5 changes: 2 additions & 3 deletions mockgen/model/model.go
Expand Up @@ -71,7 +71,7 @@ func (intf *Interface) addImports(im map[string]bool) {
}
}

// AddMethod adds a new method, deduplicating by method name.
// AddMethod adds a new method, de-duplicating by method name.
func (intf *Interface) AddMethod(m *Method) {
for _, me := range intf.Methods {
if me.Name == m.Name {
Expand Down Expand Up @@ -260,11 +260,10 @@ func (mt *MapType) addImports(im map[string]bool) {
// NamedType is an exported type in a package.
type NamedType struct {
Package string // may be empty
Type string // TODO: should this be typed Type?
Type string
}

func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
// TODO: is this right?
if pkgOverride == nt.Package {
return nt.Type
}
Expand Down

0 comments on commit 7105dde

Please sign in to comment.