diff --git a/gomock/call.go b/gomock/call.go index 3f77be4e..d921a799 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -113,8 +113,14 @@ func (c *Call) DoAndReturn(f interface{}) *Call { v := reflect.ValueOf(f) c.addAction(func(args []interface{}) []interface{} { + c.t.Helper() vArgs := make([]reflect.Value, len(args)) ft := v.Type() + if c.methodType.NumIn() != ft.NumIn() { + c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin) + return nil + } for i := 0; i < len(args); i++ { if args[i] != nil { vArgs[i] = reflect.ValueOf(args[i]) @@ -142,6 +148,12 @@ func (c *Call) Do(f interface{}) *Call { v := reflect.ValueOf(f) c.addAction(func(args []interface{}) []interface{} { + c.t.Helper() + if c.methodType.NumIn() != v.Type().NumIn() { + c.t.Fatalf("wrong number of arguments in Do func for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, v.Type().NumIn(), c.methodType.NumIn(), c.origin) + return nil + } vArgs := make([]reflect.Value, len(args)) ft := v.Type() for i := 0; i < len(args); i++ { diff --git a/gomock/call_test.go b/gomock/call_test.go index 49e6986a..9483c4f5 100644 --- a/gomock/call_test.go +++ b/gomock/call_test.go @@ -142,7 +142,7 @@ var testCases []testCase = []testCase{ doFunc: func(x int) {}, callFunc: func(x int, y int) {}, args: []interface{}{0, 1}, - expectPanic: true, + expectPanic: false, }, { description: "number of args for Do func don't match Call func", doFunc: func(x int) bool { @@ -152,7 +152,7 @@ var testCases []testCase = []testCase{ return true }, args: []interface{}{0, 1}, - expectPanic: true, + expectPanic: false, }, { description: "arg type for Do func incompatible with Call func", doFunc: func(x int) {}, @@ -481,6 +481,104 @@ func TestCall_Do(t *testing.T) { } } +func TestCall_Do_NumArgValidation(t *testing.T) { + tests := []struct { + name string + methodType reflect.Type + doFn interface{} + args []interface{} + wantErr bool + }{ + { + name: "too few", + methodType: reflect.TypeOf(func(one, two string) {}), + doFn: func(one string) {}, + args: []interface{}{"too", "few"}, + wantErr: true, + }, + { + name: "too many", + methodType: reflect.TypeOf(func(one, two string) {}), + doFn: func(one, two, three string) {}, + args: []interface{}{"too", "few"}, + wantErr: true, + }, + { + name: "just right", + methodType: reflect.TypeOf(func(one, two string) {}), + doFn: func(one string, two string) {}, + args: []interface{}{"just", "right"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &mockTestReporter{} + call := &Call{ + t: tr, + methodType: tt.methodType, + } + call.Do(tt.doFn) + call.actions[0](tt.args) + if tt.wantErr && tr.fatalCalls != 1 { + t.Fatalf("expected call to fail") + } + if !tt.wantErr && tr.fatalCalls != 0 { + t.Fatalf("expected call to pass") + } + }) + } +} + +func TestCall_DoAndReturn_NumArgValidation(t *testing.T) { + tests := []struct { + name string + methodType reflect.Type + doFn interface{} + args []interface{} + wantErr bool + }{ + { + name: "too few", + methodType: reflect.TypeOf(func(one, two string) string { return "" }), + doFn: func(one string) {}, + args: []interface{}{"too", "few"}, + wantErr: true, + }, + { + name: "too many", + methodType: reflect.TypeOf(func(one, two string) string { return "" }), + doFn: func(one, two, three string) string { return "" }, + args: []interface{}{"too", "few"}, + wantErr: true, + }, + { + name: "just right", + methodType: reflect.TypeOf(func(one, two string) string { return "" }), + doFn: func(one string, two string) string { return "" }, + args: []interface{}{"just", "right"}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &mockTestReporter{} + call := &Call{ + t: tr, + methodType: tt.methodType, + } + call.DoAndReturn(tt.doFn) + call.actions[0](tt.args) + if tt.wantErr && tr.fatalCalls != 1 { + t.Fatalf("expected call to fail") + } + if !tt.wantErr && tr.fatalCalls != 0 { + t.Fatalf("expected call to pass") + } + }) + } +} + func TestCall_DoAndReturn(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) {