diff --git a/dialect/sql/schema/migrate.go b/dialect/sql/schema/migrate.go index 91adcdbc49..57a8cda3b9 100644 --- a/dialect/sql/schema/migrate.go +++ b/dialect/sql/schema/migrate.go @@ -171,7 +171,7 @@ func (m *Migrate) Diff(ctx context.Context, tables ...*Table) error { if err != nil { return err } - // Skip if the plan has no changes + // Skip if the plan has no changes. if len(plan.Changes) == 0 { return nil } diff --git a/schema/field/field.go b/schema/field/field.go index 0c20eda2b3..c895b304e9 100644 --- a/schema/field/field.go +++ b/schema/field/field.go @@ -1165,37 +1165,43 @@ func (d *Descriptor) goType(typ interface{}, expectType reflect.Type) { Methods: make(map[string]struct{ In, Out []*RType }, t.NumMethod()), }, } + methods(t, info.RType) switch t.Kind() { case reflect.Slice, reflect.Ptr, reflect.Map: info.Nillable = true } switch pt := reflect.PtrTo(t); { - case pt.Implements(valueScannerType): - t = pt - fallthrough - case t.Implements(valueScannerType): - n := t.NumMethod() - for i := 0; i < n; i++ { - m := t.Method(i) - in := make([]*RType, m.Type.NumIn()-1) - for j := range in { - arg := m.Type.In(j + 1) - in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()} - } - out := make([]*RType, m.Type.NumOut()) - for j := range out { - ret := m.Type.Out(j) - out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()} - } - info.RType.Methods[m.Name] = struct{ In, Out []*RType }{in, out} - } - case t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType): + case pt.Implements(valueScannerType), t.Implements(valueScannerType), + t.Kind() == expectType.Kind() && t.ConvertibleTo(expectType): default: d.Err = fmt.Errorf("GoType must be a %q type or ValueScanner", expectType) } d.Info = info } +func methods(t reflect.Type, rtype *RType) { + // For type T, add methods with + // pointer receiver as well (*T). + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + n := t.NumMethod() + for i := 0; i < n; i++ { + m := t.Method(i) + in := make([]*RType, m.Type.NumIn()-1) + for j := range in { + arg := m.Type.In(j + 1) + in[j] = &RType{Name: arg.Name(), Ident: arg.String(), Kind: arg.Kind(), PkgPath: arg.PkgPath()} + } + out := make([]*RType, m.Type.NumOut()) + for j := range out { + ret := m.Type.Out(j) + out[j] = &RType{Name: ret.Name(), Ident: ret.String(), Kind: ret.Kind(), PkgPath: ret.PkgPath()} + } + rtype.Methods[m.Name] = struct{ In, Out []*RType }{in, out} + } +} + func (d *Descriptor) checkDefaultFunc(expectType reflect.Type) { for _, typ := range []reflect.Type{reflect.TypeOf(d.Default), reflect.TypeOf(d.UpdateDefault)} { if typ == nil || typ.Kind() != reflect.Func || d.Err != nil { diff --git a/schema/field/field_test.go b/schema/field/field_test.go index 84740d6fb2..88bd8b0d1a 100644 --- a/schema/field/field_test.go +++ b/schema/field/field_test.go @@ -8,14 +8,18 @@ import ( "database/sql" "database/sql/driver" "errors" + "fmt" + "io" "net" "net/http" "net/url" "reflect" "regexp" + "strconv" "testing" "time" + "entgo.io/ent" "entgo.io/ent/dialect" "entgo.io/ent/schema/field" @@ -680,6 +684,77 @@ func TestField_Other(t *testing.T) { assert.Error(t, fd.Err, "invalid default value") } +type UserRole string + +const ( + Admin UserRole = "ADMIN" + User UserRole = "USER" + Unknown UserRole = "UNKNOWN" +) + +func (UserRole) Values() (roles []string) { + for _, r := range []UserRole{Admin, User, Unknown} { + roles = append(roles, string(r)) + } + return +} + +func (e UserRole) String() string { + return string(e) +} + +// MarshalGQL implements graphql.Marshaler interface. +func (e UserRole) MarshalGQL(w io.Writer) { + _, _ = io.WriteString(w, strconv.Quote(e.String())) +} + +// UnmarshalGQL implements graphql.Unmarshaler interface. +func (e *UserRole) UnmarshalGQL(val interface{}) error { + str, ok := val.(string) + if !ok { + return fmt.Errorf("enum %T must be a string", val) + } + *e = UserRole(str) + switch *e { + case Admin, User, Unknown: + return nil + default: + return fmt.Errorf("%s is not a valid Role", str) + } +} + +type Scalar struct{} + +func (Scalar) MarshalGQL(io.Writer) {} +func (*Scalar) UnmarshalGQL(interface{}) error { return nil } +func (Scalar) Value() (driver.Value, error) { return nil, nil } + +func TestRType_Implements(t *testing.T) { + type ( + marshaler interface{ MarshalGQL(w io.Writer) } + unmarshaler interface{ UnmarshalGQL(v interface{}) error } + codec interface { + marshaler + unmarshaler + } + ) + var ( + codecType = reflect.TypeOf((*codec)(nil)).Elem() + marshalType = reflect.TypeOf((*marshaler)(nil)).Elem() + unmarshalType = reflect.TypeOf((*unmarshaler)(nil)).Elem() + ) + for _, f := range []ent.Field{ + field.Enum("role").GoType(Admin), + field.Other("scalar", &Scalar{}), + field.Other("scalar", Scalar{}), + } { + fd := f.Descriptor() + assert.True(t, fd.Info.RType.Implements(codecType)) + assert.True(t, fd.Info.RType.Implements(marshalType)) + assert.True(t, fd.Info.RType.Implements(unmarshalType)) + } +} + func TestTypeString(t *testing.T) { typ := field.TypeBool assert.Equal(t, "bool", typ.String()) diff --git a/schema/field/type.go b/schema/field/type.go index addbffc699..15e1964b12 100644 --- a/schema/field/type.go +++ b/schema/field/type.go @@ -120,12 +120,12 @@ func (t TypeInfo) ConstName() string { // ValueScanner indicates if this type implements the ValueScanner interface. func (t TypeInfo) ValueScanner() bool { - return t.RType.implements(valueScannerType) + return t.RType.Implements(valueScannerType) } // Valuer indicates if this type implements the driver.Valuer interface. func (t TypeInfo) Valuer() bool { - return t.RType.implements(valuerType) + return t.RType.Implements(valuerType) } // Comparable reports whether values of this type are comparable. @@ -147,7 +147,7 @@ var stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem() // Stringer indicates if this type implements the Stringer interface. func (t TypeInfo) Stringer() bool { - return t.RType.implements(stringerType) + return t.RType.Implements(stringerType) } var ( @@ -215,7 +215,8 @@ func (r *RType) IsPtr() bool { return r != nil && r.Kind == reflect.Ptr } -func (r *RType) implements(typ reflect.Type) bool { +// Implements reports whether the RType ~implements the given interface type. +func (r *RType) Implements(typ reflect.Type) bool { if r == nil { return false }