diff --git a/mockgen/internal/tests/const_array_length/input.go b/mockgen/internal/tests/const_array_length/input.go index 008d57e3..95576e7b 100644 --- a/mockgen/internal/tests/const_array_length/input.go +++ b/mockgen/internal/tests/const_array_length/input.go @@ -1,5 +1,7 @@ package const_length +import "math" + //go:generate mockgen -package const_length -destination mock.go -source input.go const C = 2 @@ -7,4 +9,5 @@ const C = 2 type I interface { Foo() [C]int Bar() [2]int + Baz() [math.MaxInt8]int } diff --git a/mockgen/internal/tests/const_array_length/mock.go b/mockgen/internal/tests/const_array_length/mock.go index 71853b7e..8107f641 100644 --- a/mockgen/internal/tests/const_array_length/mock.go +++ b/mockgen/internal/tests/const_array_length/mock.go @@ -47,6 +47,20 @@ func (mr *MockIMockRecorder) Bar() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bar", reflect.TypeOf((*MockI)(nil).Bar)) } +// Baz mocks base method. +func (m *MockI) Baz() [127]int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Baz") + ret0, _ := ret[0].([127]int) + return ret0 +} + +// Baz indicates an expected call of Baz. +func (mr *MockIMockRecorder) Baz() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Baz", reflect.TypeOf((*MockI)(nil).Baz)) +} + // Foo mocks base method. func (m *MockI) Foo() [2]int { m.ctrl.T.Helper() diff --git a/mockgen/parse.go b/mockgen/parse.go index 9975ae10..bf6902cd 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -22,8 +22,10 @@ import ( "fmt" "go/ast" "go/build" + "go/importer" "go/parser" "go/token" + "go/types" "io/ioutil" "log" "path" @@ -409,8 +411,19 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { case (*ast.BasicLit): value = val.Value case (*ast.Ident): - // when the length is a const + // when the length is a const defined locally value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value + case (*ast.SelectorExpr): + // when the length is a const defined in an external package + usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X)) + if err != nil { + return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err) + } + ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name) + if err != nil { + return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err) + } + value = ev.Value.String() } x, err := strconv.Atoi(value) diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index a7ea9f82..7e224183 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -116,14 +116,19 @@ func Benchmark_parseFile(b *testing.B) { func TestParseArrayWithConstLength(t *testing.T) { fs := token.NewFileSet() + srcDir := "internal/tests/const_array_length/input.go" - file, err := parser.ParseFile(fs, "internal/tests/const_array_length/input.go", nil, 0) + file, err := parser.ParseFile(fs, srcDir, nil, 0) if err != nil { t.Fatalf("Unexpected error: %v", err) } p := fileParser{ - fileSet: fs, + fileSet: fs, + imports: make(map[string]importedPackage), + importedInterfaces: make(map[string]map[string]*ast.InterfaceType), + auxInterfaces: make(map[string]map[string]*ast.InterfaceType), + srcDir: srcDir, } pkg, err := p.parseFile("", file) @@ -131,9 +136,11 @@ func TestParseArrayWithConstLength(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - expect := "[2]int" - got := pkg.Interfaces[0].Methods[0].Out[0].Type.String(nil, "") - if got != expect { - t.Fatalf("got %v; expected %v", got, expect) + expects := []string{"[2]int", "[2]int", "[127]int"} + for i, e := range expects { + got := pkg.Interfaces[0].Methods[i].Out[0].Type.String(nil, "") + if got != e { + t.Fatalf("got %v; expected %v", got, e) + } } }