Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

feat validate Do & DoReturn args #558

Merged
merged 1 commit into from May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions gomock/call.go
Expand Up @@ -113,8 +113,14 @@ func (c *Call) DoAndReturn(f interface{}) *Call {
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
c.t.Helper()
vArgs := make([]reflect.Value, len(args))
ft := v.Type()
if c.methodType.NumIn() != ft.NumIn() {
c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v: got %d, want %d [%s]",
c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin)
return nil
}
for i := 0; i < len(args); i++ {
if args[i] != nil {
vArgs[i] = reflect.ValueOf(args[i])
Expand Down Expand Up @@ -142,6 +148,12 @@ func (c *Call) Do(f interface{}) *Call {
v := reflect.ValueOf(f)

c.addAction(func(args []interface{}) []interface{} {
c.t.Helper()
if c.methodType.NumIn() != v.Type().NumIn() {
c.t.Fatalf("wrong number of arguments in Do func for %T.%v: got %d, want %d [%s]",
c.receiver, c.method, v.Type().NumIn(), c.methodType.NumIn(), c.origin)
return nil
}
vArgs := make([]reflect.Value, len(args))
ft := v.Type()
for i := 0; i < len(args); i++ {
Expand Down
102 changes: 100 additions & 2 deletions gomock/call_test.go
Expand Up @@ -142,7 +142,7 @@ var testCases []testCase = []testCase{
doFunc: func(x int) {},
callFunc: func(x int, y int) {},
args: []interface{}{0, 1},
expectPanic: true,
expectPanic: false,
}, {
description: "number of args for Do func don't match Call func",
doFunc: func(x int) bool {
Expand All @@ -152,7 +152,7 @@ var testCases []testCase = []testCase{
return true
},
args: []interface{}{0, 1},
expectPanic: true,
expectPanic: false,
}, {
description: "arg type for Do func incompatible with Call func",
doFunc: func(x int) {},
Expand Down Expand Up @@ -481,6 +481,104 @@ func TestCall_Do(t *testing.T) {
}
}

func TestCall_Do_NumArgValidation(t *testing.T) {
tests := []struct {
name string
methodType reflect.Type
doFn interface{}
args []interface{}
wantErr bool
}{
{
name: "too few",
methodType: reflect.TypeOf(func(one, two string) {}),
doFn: func(one string) {},
args: []interface{}{"too", "few"},
wantErr: true,
},
{
name: "too many",
methodType: reflect.TypeOf(func(one, two string) {}),
doFn: func(one, two, three string) {},
args: []interface{}{"too", "few"},
wantErr: true,
},
{
name: "just right",
methodType: reflect.TypeOf(func(one, two string) {}),
doFn: func(one string, two string) {},
args: []interface{}{"just", "right"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := &mockTestReporter{}
call := &Call{
t: tr,
methodType: tt.methodType,
}
call.Do(tt.doFn)
call.actions[0](tt.args)
if tt.wantErr && tr.fatalCalls != 1 {
t.Fatalf("expected call to fail")
}
if !tt.wantErr && tr.fatalCalls != 0 {
t.Fatalf("expected call to pass")
}
})
}
}

func TestCall_DoAndReturn_NumArgValidation(t *testing.T) {
tests := []struct {
name string
methodType reflect.Type
doFn interface{}
args []interface{}
wantErr bool
}{
{
name: "too few",
methodType: reflect.TypeOf(func(one, two string) string { return "" }),
doFn: func(one string) {},
args: []interface{}{"too", "few"},
wantErr: true,
},
{
name: "too many",
methodType: reflect.TypeOf(func(one, two string) string { return "" }),
doFn: func(one, two, three string) string { return "" },
args: []interface{}{"too", "few"},
wantErr: true,
},
{
name: "just right",
methodType: reflect.TypeOf(func(one, two string) string { return "" }),
doFn: func(one string, two string) string { return "" },
args: []interface{}{"just", "right"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := &mockTestReporter{}
call := &Call{
t: tr,
methodType: tt.methodType,
}
call.DoAndReturn(tt.doFn)
call.actions[0](tt.args)
if tt.wantErr && tr.fatalCalls != 1 {
t.Fatalf("expected call to fail")
}
if !tt.wantErr && tr.fatalCalls != 0 {
t.Fatalf("expected call to pass")
}
})
}
}

func TestCall_DoAndReturn(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
Expand Down