Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

Commit

Permalink
support embedded generic interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
n0trace authored and wangyufeng04 committed May 8, 2023
1 parent 6f3e5ba commit 6adf1d1
Show file tree
Hide file tree
Showing 9 changed files with 985 additions and 234 deletions.
39 changes: 34 additions & 5 deletions mockgen/generic_go118.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package main

import (
"fmt"
"go/ast"
"strings"

Expand All @@ -24,25 +25,25 @@ func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field {
return ts.TypeParams.List
}

func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]bool) (model.Type, error) {
func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]bool, tm *instTypeMatcher) (model.Type, error) {
switch v := typ.(type) {
case *ast.IndexExpr:
m, err := p.parseType(pkg, v.X, tps)
m, err := p.parseType(pkg, v.X, tps, tm)
if err != nil {
return nil, err
}
nm, ok := m.(*model.NamedType)
if !ok {
return m, nil
}
t, err := p.parseType(pkg, v.Index, tps)
t, err := p.parseType(pkg, v.Index, tps, tm)
if err != nil {
return nil, err
}
nm.TypeParams = &model.TypeParametersType{TypeParameters: []model.Type{t}}
return m, nil
case *ast.IndexListExpr:
m, err := p.parseType(pkg, v.X, tps)
m, err := p.parseType(pkg, v.X, tps, tm)
if err != nil {
return nil, err
}
Expand All @@ -52,7 +53,7 @@ func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]b
}
var ts []model.Type
for _, expr := range v.Indices {
t, err := p.parseType(pkg, expr, tps)
t, err := p.parseType(pkg, expr, tps, tm)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -86,3 +87,31 @@ func getIdentTypeParams(decl interface{}) string {
sb.WriteString("]")
return sb.String()
}

func (p *fileParser) parseGenericMethods(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]bool) ([]*model.Method, error) {
var indices []ast.Expr
var typ ast.Expr
switch v := field.Type.(type) {
case *ast.IndexExpr:
indices = []ast.Expr{v.Index}
typ = v.X
case *ast.IndexListExpr:
indices = v.Indices
typ = v.X
default:
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
}

nf := &ast.Field{
Doc: field.Comment,
Names: field.Names,
Type: typ,
Tag: field.Tag,
Comment: field.Comment,
}

it.instTypeParams = indices
it.pkg = pkg

return p.parseMethods(nf, it, iface, pkg, tps)
}
7 changes: 6 additions & 1 deletion mockgen/generic_notgo118.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package main

import (
"fmt"
"go/ast"

"github.com/golang/mock/mockgen/model"
Expand All @@ -27,10 +28,14 @@ func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field {
return nil
}

func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]bool) (model.Type, error) {
func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]bool, tm *instTypeMatcher) (model.Type, error) {
return nil, nil
}

func getIdentTypeParams(decl interface{}) string {
return ""
}

func (p *fileParser) parseGenericMethods(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]bool) ([]*model.Method, error) {
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
}
49 changes: 47 additions & 2 deletions mockgen/internal/tests/generics/external.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package generics

import (
"context"

"github.com/golang/mock/mockgen/internal/tests/generics/other"
"golang.org/x/exp/constraints"
)
Expand All @@ -20,6 +22,49 @@ type ExternalConstraint[I constraints.Integer, F constraints.Float] interface {
Ten(*I)
}

type TwentyTwo[T any] interface {
TwentyTwo() T
type EmbeddingIface[T constraints.Integer, R constraints.Float] interface {
other.Twenty[T, StructType, R, other.Five]
TwentyTwo[StructType]
other.TwentyThree[TwentyTwo[R], TwentyTwo[T]]
TwentyFour[other.StructType]
Foo() error
ExternalConstraint[T, R]
}

type TwentyOne[T any] interface {
TwentyOne() T
}

type TwentyFour[T other.StructType] interface {
TwentyFour() T
}

type Clonable[T any] interface {
Clone() T
}

type Finder[T Clonable[T]] interface {
Find(ctx context.Context) ([]T, error)
}

type UpdateNotifier[T any] interface {
NotifyC(ctx context.Context) <-chan []T

Refresh(ctx context.Context)
}

type EmbeddedW[W StructType] interface {
EmbeddedY[W]
}

type EmbeddedX[X StructType] interface {
EmbeddedY[X]
}

type EmbeddedY[Y StructType] interface {
EmbeddedZ[Y]
}

type EmbeddedZ[Z any] interface {
EmbeddedZ(Z)
}
12 changes: 5 additions & 7 deletions mockgen/internal/tests/generics/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@ type Bar[T any, R any] interface {
Seventeen() (*Foo[other.Three, other.Four], error)
Eighteen() (Iface[*other.Five], error)
Nineteen() AliasType
other.Twenty[T]
TwentyOne[T]
TwentyTwo[T]
}

type TwentyOne[T any] interface {
TwentyOne() T
// other.Twenty[any, any, any, *other.Four]
}

type Foo[T any, R any] struct{}
Expand All @@ -45,3 +39,7 @@ type StructType struct{}
type StructType2 struct{}

type AliasType Baz[other.Three]

type TwentyTwo[T any] interface {
TwentyTwo() T
}
10 changes: 8 additions & 2 deletions mockgen/internal/tests/generics/other/other.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ type Four struct{}

type Five interface{}

type Twenty[T any] interface {
Twenty() T
type Twenty[R, S, T any, Z any] interface {
Twenty(S, R) (T, Z)
}

type TwentyThree[U, V any] interface {
TwentyThree(U, V) StructType
}

type StructType struct{}
12 changes: 12 additions & 0 deletions mockgen/internal/tests/generics/source/assert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package source

import (
"testing"

"github.com/golang/mock/mockgen/internal/tests/generics"
)

func TestAssert(t *testing.T) {
var x MockEmbeddingIface[int, float64]
var _ generics.EmbeddingIface[int, float64] = &x
}

0 comments on commit 6adf1d1

Please sign in to comment.