diff --git a/pkg/generator.go b/pkg/generator.go index 52f9c40a..84c3d89b 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -103,7 +103,12 @@ func (g *Generator) addImportsFromTuple(ctx context.Context, list *types.Tuple) } } -func (g *Generator) addPackageScopedType(ctx context.Context, o *types.TypeName) string { +// getPackageScopedType returns the appropriate string representation for the +// object TypeName. The string may either be the unqualified name (in the case +// the mock will live in the same package as the interface being mocked, e.g. +// `Foo`) or the package pathname (in the case the type lives in a package +// external to the mock, e.g. `packagename.Foo`). +func (g *Generator) getPackageScopedType(ctx context.Context, o *types.TypeName) string { if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) { return o.Name() } @@ -253,7 +258,17 @@ func (g *Generator) mockName() string { return g.maybeMakeNameExported(g.iface.Name, g.Exported) } -func (g *Generator) typeConstraints(ctx context.Context) string { +// getTypeConstraintString returns type constraint string for a given interface. +// For instance, a method using this constraint: +// +// func Foo[T Stringer](s []T) (ret []string) { +// +// } +// +// The constraint returned will be "[T Stringer]" +// +// https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#type-parameters +func (g *Generator) getTypeConstraintString(ctx context.Context) string { tp := g.iface.NamedType.TypeParams() if tp == nil || tp.Len() == 0 { return "" @@ -266,7 +281,10 @@ func (g *Generator) typeConstraints(ctx context.Context) string { return fmt.Sprintf("[%s]", strings.Join(qualifiedParams, ", ")) } -func (g *Generator) typeParams() string { +// getInstantiatedTypeString returns the "instantiated" type names for a given +// constraint list. For instance, if your interface has the constraints +// `[S Stringer, I int, C Comparable]`, this method would return: `[S, I, C]` +func (g *Generator) getInstantiatedTypeString() string { tp := g.iface.NamedType.TypeParams() if tp == nil || tp.Len() == 0 { return "" @@ -382,7 +400,7 @@ type namer interface { func (g *Generator) renderType(ctx context.Context, typ types.Type) string { switch t := typ.(type) { case *types.Named: - name := g.addPackageScopedType(ctx, t.Obj()) + name := g.getPackageScopedType(ctx, t.Obj()) if t.TypeArgs() == nil || t.TypeArgs().Len() == 0 { return name } @@ -396,7 +414,7 @@ func (g *Generator) renderType(ctx context.Context, typ types.Type) string { if t.Constraint() != nil { return t.Obj().Name() } - return g.addPackageScopedType(ctx, t.Obj()) + return g.getPackageScopedType(ctx, t.Obj()) case *types.Basic: if t.Kind() == types.UnsafePointer { return "unsafe.Pointer" @@ -589,7 +607,7 @@ func (g *Generator) Generate(ctx context.Context) error { ) g.printf( - "type %s%s struct {\n\tmock.Mock\n}\n\n", g.mockName(), g.typeConstraints(ctx), + "type %s%s struct {\n\tmock.Mock\n}\n\n", g.mockName(), g.getTypeConstraintString(ctx), ) if g.WithExpecter { @@ -618,7 +636,7 @@ func (g *Generator) Generate(ctx context.Context) error { ) } g.printf( - "func (_m *%s%s) %s(%s) ", g.mockName(), g.typeParams(ctx), fname, + "func (_m *%s%s) %s(%s) ", g.mockName(), g.getInstantiatedTypeString(), fname, strings.Join(params.Params, ", "), ) @@ -805,7 +823,7 @@ func %[1]s%[3]s(t testing.TB) *%[2]s%[4]s { mockName := g.mockName() constructorName := g.maybeMakeNameExported("new"+g.makeNameExported(mockName), ast.IsExported(mockName)) - g.printf(constructor, constructorName, mockName, g.typeConstraints(ctx), g.typeParams()) + g.printf(constructor, constructorName, mockName, g.getTypeConstraintString(ctx), g.getInstantiatedTypeString()) } // generateCalled returns the Mock.Called invocation string and, if necessary, prints the diff --git a/pkg/generator_test.go b/pkg/generator_test.go index 7a096d72..d8edf023 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -1925,12 +1925,13 @@ func (_m *RequesterGenerics[TAny, TComparable, TSigned, TIntf, TExternalIntf, TG return r0 } -// NewRequesterGenerics creates a new instance of RequesterGenerics. It also registers a cleanup function to assert the mocks expectations. +// NewRequesterGenerics creates a new instance of RequesterGenerics. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. func NewRequesterGenerics[TAny interface{}, TComparable comparable, TSigned constraints.Signed, TIntf test.GetInt, TExternalIntf io.Writer, TGenIntf test.GetGeneric[TSigned], TInlineType interface{ ~int | ~uint }, TInlineTypeGeneric interface { ~int | test.GenericType[int, test.GetInt] comparable }](t testing.TB) *RequesterGenerics[TAny, TComparable, TSigned, TIntf, TExternalIntf, TGenIntf, TInlineType, TInlineTypeGeneric] { mock := &RequesterGenerics[TAny, TComparable, TSigned, TIntf, TExternalIntf, TGenIntf, TInlineType, TInlineTypeGeneric]{} + mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) @@ -2006,12 +2007,13 @@ func (_m *MockRequesterGenerics[TAny, TComparable, TSigned, TIntf, TExternalIntf return r0 } -// NewMockRequesterGenerics creates a new instance of MockRequesterGenerics. It also registers a cleanup function to assert the mocks expectations. +// NewMockRequesterGenerics creates a new instance of MockRequesterGenerics. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. func NewMockRequesterGenerics[TAny interface{}, TComparable comparable, TSigned constraints.Signed, TIntf GetInt, TExternalIntf io.Writer, TGenIntf GetGeneric[TSigned], TInlineType interface{ ~int | ~uint }, TInlineTypeGeneric interface { ~int | GenericType[int, GetInt] comparable }](t testing.TB) *MockRequesterGenerics[TAny, TComparable, TSigned, TIntf, TExternalIntf, TGenIntf, TInlineType, TInlineTypeGeneric] { mock := &MockRequesterGenerics[TAny, TComparable, TSigned, TIntf, TExternalIntf, TGenIntf, TInlineType, TInlineTypeGeneric]{} + mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) })