From c99b33ac959a871dea49458597b2fba343a80cc5 Mon Sep 17 00:00:00 2001 From: Georgii Kliukovkin Date: Fri, 24 Jun 2022 10:18:54 +0300 Subject: [PATCH] add interfaces flag --- mockgen/mockgen.go | 34 ++++++++++++++ mockgen/mockgen_test.go | 101 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 72a1c960..b52c41b9 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -60,6 +60,7 @@ var ( selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") + interfaces = flag.String("interfaces", "", "List of interfaces to generate mocks for; if empty, mockgen will generate mocks for all interfaces found in the input file(s).") debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") showVersion = flag.Bool("version", false, "Print version.") @@ -107,6 +108,14 @@ func main() { return } + if len(*interfaces) > 0 { + ifaces := strings.Split(*interfaces, ",") + if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil { + log.Fatalf("Filtering interfaces failed: %v", err) + } + + } + dst := os.Stdout if len(*destination) > 0 { if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { @@ -725,3 +734,28 @@ func parsePackageImport(srcDir string) (string, error) { } return "", errOutsideGoPath } + +func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) { + if len(requested) == 0 { + return nil, fmt.Errorf("no interfaces requested, other provide them or remove flag -interfaces") + } + requestedIfaces := make(map[string]struct{}) + for _, iface := range requested { + requestedIfaces[iface] = struct{}{} + } + result := make([]*model.Interface, 0, len(all)) + for _, iface := range all { + if _, ok := requestedIfaces[iface.Name]; ok { + result = append(result, iface) + delete(requestedIfaces, iface.Name) + } + } + if len(requestedIfaces) > 0 { + var missing []string + for iface := range requestedIfaces { + missing = append(missing, iface) + } + return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", ")) + } + return result, nil +} diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 55566001..c9f4d654 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -449,3 +449,104 @@ func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) { t.Errorf("expect %s, got %s", expected, pkgPath) } } + +func Test_filterInterfaces1(t *testing.T) { + type args struct { + all []*model.Interface + requested []string + } + tests := []struct { + name string + args args + want []*model.Interface + wantErr bool + }{ + { + name: "no filter", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{}, + }, + want: nil, + wantErr: true, + }, + { + name: "filter by Foo", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo"}, + }, + want: []*model.Interface{ + { + Name: "Foo", + }, + }, + wantErr: false, + }, + { + name: "filter by Foo and Bar", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo", "Bar"}, + }, + want: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + wantErr: false, + }, + { + name: "incorrect filter by Foo and Baz", + args: args{ + all: []*model.Interface{ + { + Name: "Foo", + }, + { + Name: "Bar", + }, + }, + requested: []string{"Foo", "Baz"}, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := filterInterfaces(tt.args.all, tt.args.requested) + if (err != nil) != tt.wantErr { + t.Errorf("filterInterfaces() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("filterInterfaces() got = %v, want %v", got, tt.want) + } + }) + } +}