diff --git a/pkg/fixtures/example_project/bar/foo/client.go b/pkg/fixtures/example_project/bar/foo/client.go new file mode 100644 index 00000000..dc2e6c74 --- /dev/null +++ b/pkg/fixtures/example_project/bar/foo/client.go @@ -0,0 +1,5 @@ +package foo + +type Client interface { + Search(query string) ([]string, error) +} diff --git a/pkg/fixtures/example_project/foo/collision.go b/pkg/fixtures/example_project/foo/collision.go new file mode 100644 index 00000000..1254e9fc --- /dev/null +++ b/pkg/fixtures/example_project/foo/collision.go @@ -0,0 +1,7 @@ +package foo + +import "github.com/vektra/mockery/v2/pkg/fixtures/example_project/bar/foo" + +type Collision interface { + NewClient() foo.Client +} diff --git a/pkg/generator.go b/pkg/generator.go index 5c64277e..1c0fc5ae 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -103,7 +103,10 @@ func (g *Generator) addPackageImportWithName(ctx context.Context, path, name str func (g *Generator) getNonConflictingName(path, name string) string { if !g.importNameExists(name) { - return name + // do not allow imports with the same name as the package when inPackage + if !g.InPackage || g.iface.Pkg.Name() != name { + return name + } } // The path will always contain '/' because it is enforced in getLocalizedPath @@ -120,7 +123,10 @@ func (g *Generator) getNonConflictingName(path, name string) string { for i := 1; i <= numDirectories; i++ { prospectiveName = strings.Join(cleanedDirectories[numDirectories-i:], "") if !g.importNameExists(prospectiveName) { - return prospectiveName + // do not allow imports with the same name as the package when inPackage + if !g.InPackage || g.iface.Pkg.Name() != prospectiveName { + return prospectiveName + } } } // Try adding numbers to the given name diff --git a/pkg/generator_test.go b/pkg/generator_test.go index d92deddb..530b57c7 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -1402,6 +1402,22 @@ import mock "github.com/stretchr/testify/mock" } } +func (s *GeneratorSuite) TestInPackagePackageCollision() { + expected := `package foo + +import barfoo "github.com/vektra/mockery/v2/pkg/fixtures/example_project/bar/foo" +import mock "github.com/stretchr/testify/mock" + +` + generator := NewGenerator( + s.ctx, + config.Config{InPackage: true, LogLevel: "debug"}, + s.getInterfaceFromFile("example_project/foo/collision.go", "Collision"), + pkg, + ) + s.checkPrologueGeneration(generator, expected) +} + func TestGeneratorSuite(t *testing.T) { generatorSuite := new(GeneratorSuite) suite.Run(t, generatorSuite)