Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schema/field: expose RType.Implements method #2379

Merged
merged 1 commit into from Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dialect/sql/schema/migrate.go
Expand Up @@ -171,7 +171,7 @@ func (m *Migrate) Diff(ctx context.Context, tables ...*Table) error {
if err != nil {
return err
}
// Skip if the plan has no changes
// Skip if the plan has no changes.
if len(plan.Changes) == 0 {
return nil
}
Expand Down
46 changes: 26 additions & 20 deletions schema/field/field.go
Expand Up @@ -1165,37 +1165,43 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) {
Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()),
},
}
methods(t, info.RType)
switch t.Kind() {
case reflect.Slice, reflect.Ptr, reflect.Map:
info.Nillable = true
}
switch pt := reflect.PtrTo(t); {
case pt.Implements(valueScannerType):
t = pt
fallthrough
case t.Implements(valueScannerType):
n := t.NumMethod()
for i := 0; i < n; i++ {
m := t.Method(i)
in := make([]*RType, m.Type.NumIn()-1)
for j := range in {
arg := m.Type.In(j + 1)
in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
}
out := make([]*RType, m.Type.NumOut())
for j := range out {
ret := m.Type.Out(j)
out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
}
info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
case pt.Implements(valueScannerType), t.Implements(valueScannerType),
t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType):
default:
d.Err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType)
}
d.Info = info
}

func methods(t reflect.Type, rtype *RType) {
// For type T, add methods with
// pointer receiver as well (*T).
if t.Kind() != reflect.Ptr {
t = reflect.PtrTo(t)
}
n := t.NumMethod()
for i := 0; i < n; i++ {
m := t.Method(i)
in := make([]*RType, m.Type.NumIn()-1)
for j := range in {
arg := m.Type.In(j + 1)
in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()}
}
out := make([]*RType, m.Type.NumOut())
for j := range out {
ret := m.Type.Out(j)
out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()}
}
rtype.Methods[m.Name] = struct{ In, Out []*RType }{in, out}
}
}

func (d *Descriptor) checkDefaultFunc(expectType reflect.Type) {
for _, typ := range []reflect.Type{reflect.TypeOf(d.Default), reflect.TypeOf(d.UpdateDefault)} {
if typ == nil || typ.Kind() != reflect.Func || d.Err != nil {
Expand Down
75 changes: 75 additions & 0 deletions schema/field/field_test.go
Expand Up @@ -8,14 +8,18 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
"regexp"
"strconv"
"testing"
"time"

"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/schema/field"

Expand Down Expand Up @@ -680,6 +684,77 @@ func TestField_Other(t *testing.T) {
assert.Error(t, fd.Err, "invalid default value")
}

type UserRole string

const (
Admin UserRole = "ADMIN"
User UserRole = "USER"
Unknown UserRole = "UNKNOWN"
)

func (UserRole) Values() (roles []string) {
for _, r := range []UserRole{Admin, User, Unknown} {
roles = append(roles, string(r))
}
return
}

func (e UserRole) String() string {
return string(e)
}

// MarshalGQL implements graphql.Marshaler interface.
func (e UserRole) MarshalGQL(w io.Writer) {
_, _ = io.WriteString(w, strconv.Quote(e.String()))
}

// UnmarshalGQL implements graphql.Unmarshaler interface.
func (e *UserRole) UnmarshalGQL(val interface{}) error {
str, ok := val.(string)
if !ok {
return fmt.Errorf("enum %T must be a string", val)
}
*e = UserRole(str)
switch *e {
case Admin, User, Unknown:
return nil
default:
return fmt.Errorf("%s is not a valid Role", str)
}
}

type Scalar struct{}

func (Scalar) MarshalGQL(io.Writer) {}
func (*Scalar) UnmarshalGQL(interface{}) error { return nil }
func (Scalar) Value() (driver.Value, error) { return nil, nil }

func TestRType_Implements(t *testing.T) {
type (
marshaler interface{ MarshalGQL(w io.Writer) }
unmarshaler interface{ UnmarshalGQL(v interface{}) error }
codec interface {
marshaler
unmarshaler
}
)
var (
codecType = reflect.TypeOf((*codec)(nil)).Elem()
marshalType = reflect.TypeOf((*marshaler)(nil)).Elem()
unmarshalType = reflect.TypeOf((*unmarshaler)(nil)).Elem()
)
for _, f := range []ent.Field{
field.Enum("role").GoType(Admin),
field.Other("scalar", &Scalar{}),
field.Other("scalar", Scalar{}),
} {
fd := f.Descriptor()
assert.True(t, fd.Info.RType.Implements(codecType))
assert.True(t, fd.Info.RType.Implements(marshalType))
assert.True(t, fd.Info.RType.Implements(unmarshalType))
}
}

func TestTypeString(t *testing.T) {
typ := field.TypeBool
assert.Equal(t, "bool", typ.String())
Expand Down
9 changes: 5 additions & 4 deletions schema/field/type.go
Expand Up @@ -120,12 +120,12 @@ func (t TypeInfo) ConstName() string {

// ValueScanner indicates if this type implements the ValueScanner interface.
func (t TypeInfo) ValueScanner() bool {
return t.RType.implements(valueScannerType)
return t.RType.Implements(valueScannerType)
}

// Valuer indicates if this type implements the driver.Valuer interface.
func (t TypeInfo) Valuer() bool {
return t.RType.implements(valuerType)
return t.RType.Implements(valuerType)
}

// Comparable reports whether values of this type are comparable.
Expand All @@ -147,7 +147,7 @@ var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()

// Stringer indicates if this type implements the Stringer interface.
func (t TypeInfo) Stringer() bool {
return t.RType.implements(stringerType)
return t.RType.Implements(stringerType)
}

var (
Expand Down Expand Up @@ -215,7 +215,8 @@ func (r *RType) IsPtr() bool {
return r != nil && r.Kind == reflect.Ptr
}

func (r *RType) implements(typ reflect.Type) bool {
// Implements reports whether the RType ~implements the given interface type.
func (r *RType) Implements(typ reflect.Type) bool {
if r == nil {
return false
}
Expand Down