Skip to content

Commit

Permalink
Generics mocking fixes
Browse files Browse the repository at this point in the history
- Add support for rendering union types
- Add support for rendering the type arguments of named types
- Add support for rendering types embedded in interfaces
- Test generic generation against a single more complicated interface
- Fix package namespacing on type arguments
  • Loading branch information
Paul Cruickshank committed Apr 22, 2022
1 parent a411f3d commit a129a32
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 186 deletions.
9 changes: 9 additions & 0 deletions pkg/fixtures/constraints/constraints.go
@@ -0,0 +1,9 @@
package constraints

type Signed interface {
~int
}

type Integer interface {
~int
}
43 changes: 30 additions & 13 deletions pkg/fixtures/generic.go
@@ -1,21 +1,38 @@
package test

type Constraint interface {
int
}
import (
"io"

type Generic[T Constraint] interface {
Get() T
}
"github.com/vektra/mockery/v2/pkg/fixtures/constraints"
)

type GenericAny[T any] interface {
Get() T
type RequesterGenerics[
TAny any,
TComparable comparable,
TSigned constraints.Signed, // external constraint
TIntf GetInt, // internal interface
TExternalIntf io.Writer, // external interface
TGenIntf GetGeneric[TSigned], // generic interface
TInlineType interface{ ~int | ~uint }, // inlined interface constraints
TInlineTypeGeneric interface {
~int | GenericType[int, GetInt]
comparable
}, // inlined type constraints
] interface {
GenericArguments(TAny, TComparable) (TSigned, TIntf)
GenericStructs(GenericType[TAny, TIntf]) GenericType[TSigned, TIntf]
GenericAnonymousStructs(struct{ Type1 TExternalIntf }) struct {
Type2 GenericType[string, EmbeddedGet[int]]
}
}

type GenericComparable[T comparable] interface {
Get() T
type GenericType[T any, S GetInt] struct {
Any T
Some []S
}

type Embedded interface {
Generic[int]
}
type GetInt interface{ Get() int }

type GetGeneric[T constraints.Integer] interface{ Get() T }

type EmbeddedGet[T constraints.Signed] interface{ GetGeneric[T] }
42 changes: 33 additions & 9 deletions pkg/generator.go
Expand Up @@ -259,12 +259,7 @@ func (g *Generator) typeConstraints(ctx context.Context) string {
qualifiedParams := make([]string, 0, tp.Len())
for i := 0; i < tp.Len(); i++ {
param := tp.At(i)
switch constraint := param.Constraint().(type) {
case *types.Named:
qualifiedParams = append(qualifiedParams, fmt.Sprintf("%s %s", param.String(), g.addPackageScopedType(ctx, constraint.Obj())))
case *types.Interface:
qualifiedParams = append(qualifiedParams, fmt.Sprintf("%s %s", param.String(), constraint.String()))
}
qualifiedParams = append(qualifiedParams, fmt.Sprintf("%s %s", param.String(), g.renderType(ctx, param.Constraint())))
}
return fmt.Sprintf("[%s]", strings.Join(qualifiedParams, ", "))
}
Expand Down Expand Up @@ -385,7 +380,16 @@ type namer interface {
func (g *Generator) renderType(ctx context.Context, typ types.Type) string {
switch t := typ.(type) {
case *types.Named:
return g.addPackageScopedType(ctx, t.Obj())
name := g.addPackageScopedType(ctx, t.Obj())
if t.TypeArgs() == nil || t.TypeArgs().Len() == 0 {
return name
}
args := make([]string, 0, t.TypeArgs().Len())
for i := 0; i < t.TypeArgs().Len(); i++ {
arg := t.TypeArgs().At(i)
args = append(args, g.renderType(ctx, arg))
}
return fmt.Sprintf("%s[%s]", name, strings.Join(args, ","))
case *types.TypeParam:
if t.Constraint() != nil {
return t.Obj().Name()
Expand Down Expand Up @@ -452,7 +456,27 @@ func (g *Generator) renderType(ctx context.Context, typ types.Type) string {
panic("Unable to mock inline interfaces with methods")
}

return "interface{}"
rv := []string{"interface{"}
for i := 0; i < t.NumEmbeddeds(); i++ {
rv = append(rv, g.renderType(ctx, t.EmbeddedType(i)))
}
rv = append(rv, "}")
sep := ""
if t.NumEmbeddeds() > 1 {
sep = "\n"
}
return strings.Join(rv, sep)
case *types.Union:
rv := make([]string, 0, t.Len())
for i := 0; i < t.Len(); i++ {
term := t.Term(i)
if term.Tilde() {
rv = append(rv, "~"+g.renderType(ctx, term.Type()))
} else {
rv = append(rv, g.renderType(ctx, term.Type()))
}
}
return strings.Join(rv, "|")
case namer:
return t.Name()
default:
Expand Down Expand Up @@ -589,7 +613,7 @@ func (g *Generator) Generate(ctx context.Context) error {
)
}
g.printf(
"func (_m *%s%s) %s(%s) ", g.mockName(), g.typeParams(), fname,
"func (_m *%s%s) %s(%s) ", g.mockName(), g.typeParams(ctx), fname,
strings.Join(params.Params, ", "),
)

Expand Down

0 comments on commit a129a32

Please sign in to comment.