Skip to content

Commit

Permalink
schema/field: expose RType.Implements method (ent#2379)
Browse files Browse the repository at this point in the history
Also, add both (T) and (*T) methods for RType
  • Loading branch information
a8m authored and gitlawr committed Apr 13, 2022
1 parent e8273c4 commit defb219
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 25 deletions.
2 changes: 1 addition & 1 deletion dialect/sql/schema/migrate.go
Expand Up @@ -171,7 +171,7 @@ func (m *Migrate) Diff(ctx context.Context, tables ...*Table) error {
if err != nil {
return err
}
// Skip if the plan has no changes
// Skip if the plan has no changes.
if len(plan.Changes) == 0 {
return nil
}
Expand Down
46 changes: 26 additions & 20 deletions schema/field/field.go
Expand Up @@ -1165,37 +1165,43 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
},
}
methods(t, info.RType)
switch t.Kind() {
case reflect.Slice, reflect.Ptr, reflect.Map:
info.Nillable = true
}
switch pt := reflect.PtrTo(t); {
case pt.Implements(valueScannerType):
t = pt
fallthrough
case t.Implements(valueScannerType):
n := t.NumMethod()
for i := 0; i < n; i++ {
m := t.Method(i)
in := make([]*RType, m.Type.NumIn()-1)
for j := range in {
arg := m.Type.In(j + 1)
in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
}
out := make([]*RType, m.Type.NumOut())
for j := range out {
ret := m.Type.Out(j)
out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
}
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
case pt.Implements(valueScannerType), t.Implements(valueScannerType),
t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
default:
d.Err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType)
}
d.Info = info
}

func methods(t reflect.Type, rtype *RType) {
// For type T, add methods with
// pointer receiver as well (*T).
if t.Kind() != reflect.Ptr {
t = reflect.PtrTo(t)
}
n := t.NumMethod()
for i := 0; i < n; i++ {
m := t.Method(i)
in := make([]*RType, m.Type.NumIn()-1)
for j := range in {
arg := m.Type.In(j + 1)
in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
}
out := make([]*RType, m.Type.NumOut())
for j := range out {
ret := m.Type.Out(j)
out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
}
rtype.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
}

func (d *Descriptor) checkDefaultFunc(expectType reflect.Type) {
for _, typ := range []reflect.Type{reflect.TypeOf(d.Default), reflect.TypeOf(d.UpdateDefault)} {
if typ == nil || typ.Kind() != reflect.Func || d.Err != nil {
Expand Down
75 changes: 75 additions & 0 deletions schema/field/field_test.go
Expand Up @@ -8,14 +8,18 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
"regexp"
"strconv"
"testing"
"time"

"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/schema/field"

Expand Down Expand Up @@ -680,6 +684,77 @@ func TestField_Other(t *testing.T) {
assert.Error(t, fd.Err, "invalid default value")
}

type UserRole string

const (
Admin UserRole = "ADMIN"
User UserRole = "USER"
Unknown UserRole = "UNKNOWN"
)

func (UserRole) Values() (roles []string) {
for _, r := range []UserRole{Admin, User, Unknown} {
roles = append(roles, string(r))
}
return
}

func (e UserRole) String() string {
return string(e)
}

// MarshalGQL implements graphql.Marshaler interface.
func (e UserRole) MarshalGQL(w io.Writer) {
_, _ = io.WriteString(w, strconv.Quote(e.String()))
}

// UnmarshalGQL implements graphql.Unmarshaler interface.
func (e *UserRole) UnmarshalGQL(val interface{}) error {
str, ok := val.(string)
if !ok {
return fmt.Errorf("enum %T must be a string", val)
}
*e = UserRole(str)
switch *e {
case Admin, User, Unknown:
return nil
default:
return fmt.Errorf("%s is not a valid Role", str)
}
}

type Scalar struct{}

func (Scalar) MarshalGQL(io.Writer) {}
func (*Scalar) UnmarshalGQL(interface{}) error { return nil }
func (Scalar) Value() (driver.Value, error) { return nil, nil }

func TestRType_Implements(t *testing.T) {
type (
marshaler interface{ MarshalGQL(w io.Writer) }
unmarshaler interface{ UnmarshalGQL(v interface{}) error }
codec interface {
marshaler
unmarshaler
}
)
var (
codecType = reflect.TypeOf((*codec)(nil)).Elem()
marshalType = reflect.TypeOf((*marshaler)(nil)).Elem()
unmarshalType = reflect.TypeOf((*unmarshaler)(nil)).Elem()
)
for _, f := range []ent.Field{
field.Enum("role").GoType(Admin),
field.Other("scalar", &Scalar{}),
field.Other("scalar", Scalar{}),
} {
fd := f.Descriptor()
assert.True(t, fd.Info.RType.Implements(codecType))
assert.True(t, fd.Info.RType.Implements(marshalType))
assert.True(t, fd.Info.RType.Implements(unmarshalType))
}
}

func TestTypeString(t *testing.T) {
typ := field.TypeBool
assert.Equal(t, "bool", typ.String())
Expand Down
9 changes: 5 additions & 4 deletions schema/field/type.go
Expand Up @@ -120,12 +120,12 @@ func (t TypeInfo) ConstName() string {

// ValueScanner indicates if this type implements the ValueScanner interface.
func (t TypeInfo) ValueScanner() bool {
return t.RType.implements(valueScannerType)
return t.RType.Implements(valueScannerType)
}

// Valuer indicates if this type implements the driver.Valuer interface.
func (t TypeInfo) Valuer() bool {
return t.RType.implements(valuerType)
return t.RType.Implements(valuerType)
}

// Comparable reports whether values of this type are comparable.
Expand All @@ -147,7 +147,7 @@ var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()

// Stringer indicates if this type implements the Stringer interface.
func (t TypeInfo) Stringer() bool {
return t.RType.implements(stringerType)
return t.RType.Implements(stringerType)
}

var (
Expand Down Expand Up @@ -215,7 +215,8 @@ func (r *RType) IsPtr() bool {
return r != nil && r.Kind == reflect.Ptr
}

func (r *RType) implements(typ reflect.Type) bool {
// Implements reports whether the RType ~implements the given interface type.
func (r *RType) Implements(typ reflect.Type) bool {
if r == nil {
return false
}
Expand Down

0 comments on commit defb219

Please sign in to comment.