From 7d4906acd6dcbfc7c97d1c021b44952c74f97739 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 25 Apr 2022 13:19:07 +0300 Subject: [PATCH] entc/integration/json: add example for using interfaces in JSON fields --- entc/integration/json/ent/migrate/schema.go | 1 + entc/integration/json/ent/mutation.go | 75 ++++++++++++++++++++- entc/integration/json/ent/schema/user.go | 49 ++++++++++++++ entc/integration/json/ent/user.go | 14 +++- entc/integration/json/ent/user/user.go | 3 + entc/integration/json/ent/user/where.go | 14 ++++ entc/integration/json/ent/user_create.go | 22 ++++++ entc/integration/json/ent/user_update.go | 66 ++++++++++++++++++ entc/integration/json/json_test.go | 14 ++++ schema/field/field.go | 12 ++-- schema/field/field_test.go | 2 + 11 files changed, 266 insertions(+), 6 deletions(-) diff --git a/entc/integration/json/ent/migrate/schema.go b/entc/integration/json/ent/migrate/schema.go index 473ee9b004..f8e8a52c83 100644 --- a/entc/integration/json/ent/migrate/schema.go +++ b/entc/integration/json/ent/migrate/schema.go @@ -22,6 +22,7 @@ var ( {Name: "ints", Type: field.TypeJSON, Nullable: true}, {Name: "floats", Type: field.TypeJSON, Nullable: true}, {Name: "strings", Type: field.TypeJSON, Nullable: true}, + {Name: "addr", Type: field.TypeJSON, Nullable: true}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/entc/integration/json/ent/mutation.go b/entc/integration/json/ent/mutation.go index 49300483f5..6b9436fbe5 100644 --- a/entc/integration/json/ent/mutation.go +++ b/entc/integration/json/ent/mutation.go @@ -47,6 +47,7 @@ type UserMutation struct { ints *[]int floats *[]float64 strings *[]string + addr *schema.Addr clearedFields map[string]struct{} done bool oldValue func(context.Context) (*User, error) @@ -481,6 +482,55 @@ func (m *UserMutation) ResetStrings() { delete(m.clearedFields, user.FieldStrings) } +// SetAddr sets the "addr" field. +func (m *UserMutation) SetAddr(s schema.Addr) { + m.addr = &s +} + +// Addr returns the value of the "addr" field in the mutation. +func (m *UserMutation) Addr() (r schema.Addr, exists bool) { + v := m.addr + if v == nil { + return + } + return *v, true +} + +// OldAddr returns the old "addr" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldAddr(ctx context.Context) (v schema.Addr, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAddr is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAddr requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAddr: %w", err) + } + return oldValue.Addr, nil +} + +// ClearAddr clears the value of the "addr" field. +func (m *UserMutation) ClearAddr() { + m.addr = nil + m.clearedFields[user.FieldAddr] = struct{}{} +} + +// AddrCleared returns if the "addr" field was cleared in this mutation. +func (m *UserMutation) AddrCleared() bool { + _, ok := m.clearedFields[user.FieldAddr] + return ok +} + +// ResetAddr resets all changes to the "addr" field. +func (m *UserMutation) ResetAddr() { + m.addr = nil + delete(m.clearedFields, user.FieldAddr) +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -500,7 +550,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 7) + fields := make([]string, 0, 8) if m.t != nil { fields = append(fields, user.FieldT) } @@ -522,6 +572,9 @@ func (m *UserMutation) Fields() []string { if m.strings != nil { fields = append(fields, user.FieldStrings) } + if m.addr != nil { + fields = append(fields, user.FieldAddr) + } return fields } @@ -544,6 +597,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.Floats() case user.FieldStrings: return m.Strings() + case user.FieldAddr: + return m.Addr() } return nil, false } @@ -567,6 +622,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldFloats(ctx) case user.FieldStrings: return m.OldStrings(ctx) + case user.FieldAddr: + return m.OldAddr(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -625,6 +682,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetStrings(v) return nil + case user.FieldAddr: + v, ok := value.(schema.Addr) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAddr(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -673,6 +737,9 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldStrings) { fields = append(fields, user.FieldStrings) } + if m.FieldCleared(user.FieldAddr) { + fields = append(fields, user.FieldAddr) + } return fields } @@ -705,6 +772,9 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldStrings: m.ClearStrings() return nil + case user.FieldAddr: + m.ClearAddr() + return nil } return fmt.Errorf("unknown User nullable field %s", name) } @@ -734,6 +804,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldStrings: m.ResetStrings() return nil + case user.FieldAddr: + m.ResetAddr() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/entc/integration/json/ent/schema/user.go b/entc/integration/json/ent/schema/user.go index b30981b652..271a9cec86 100644 --- a/entc/integration/json/ent/schema/user.go +++ b/entc/integration/json/ent/schema/user.go @@ -6,6 +6,9 @@ package schema import ( "encoding/json" + "errors" + "fmt" + "net" "net/http" "net/url" @@ -38,6 +41,8 @@ func (User) Fields() []ent.Field { Optional(), field.Strings("strings"). Optional(), + field.JSON("addr", Addr{}). + Optional(), } } @@ -50,3 +55,47 @@ type T struct { Li []int `json:"li,omitempty"` Ls []string `json:"ls,omitempty"` } + +type Addr struct{ net.Addr } + +func (a *Addr) UnmarshalJSON(data []byte) error { + var types struct { + TCP *net.TCPAddr `json:"tcp,omitempty"` + UDP *net.UDPAddr `json:"udp,omitempty"` + } + if err := json.Unmarshal(data, &types); err != nil { + return err + } + switch { + case types.TCP != nil && types.UDP != nil: + return errors.New("TCP and UDP addresses are mutually exclusive") + case types.TCP != nil: + a.Addr = types.TCP + case types.UDP != nil: + a.Addr = types.UDP + } + return nil +} + +func (a Addr) MarshalJSON() ([]byte, error) { + var types struct { + TCP *net.TCPAddr `json:"tcp,omitempty"` + UDP *net.UDPAddr `json:"udp,omitempty"` + } + switch a.Addr.(type) { + case *net.TCPAddr: + types.TCP = a.Addr.(*net.TCPAddr) + case *net.UDPAddr: + types.UDP = a.Addr.(*net.UDPAddr) + default: + return nil, fmt.Errorf("unsupported address type: %T", a.Addr) + } + return json.Marshal(types) +} + +func (a Addr) String() string { + if a.Addr == nil { + return "" + } + return a.Addr.String() +} diff --git a/entc/integration/json/ent/user.go b/entc/integration/json/ent/user.go index ec59b8c2e2..b4cb063c2f 100644 --- a/entc/integration/json/ent/user.go +++ b/entc/integration/json/ent/user.go @@ -37,6 +37,8 @@ type User struct { Floats []float64 `json:"floats,omitempty"` // Strings holds the value of the "strings" field. Strings []string `json:"strings,omitempty"` + // Addr holds the value of the "addr" field. + Addr schema.Addr `json:"addr,omitempty"` } // scanValues returns the types for scanning values from sql.Rows. @@ -44,7 +46,7 @@ func (*User) scanValues(columns []string) ([]interface{}, error) { values := make([]interface{}, len(columns)) for i := range columns { switch columns[i] { - case user.FieldT, user.FieldURL, user.FieldRaw, user.FieldDirs, user.FieldInts, user.FieldFloats, user.FieldStrings: + case user.FieldT, user.FieldURL, user.FieldRaw, user.FieldDirs, user.FieldInts, user.FieldFloats, user.FieldStrings, user.FieldAddr: values[i] = new([]byte) case user.FieldID: values[i] = new(sql.NullInt64) @@ -125,6 +127,14 @@ func (u *User) assignValues(columns []string, values []interface{}) error { return fmt.Errorf("unmarshal field strings: %w", err) } } + case user.FieldAddr: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field addr", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &u.Addr); err != nil { + return fmt.Errorf("unmarshal field addr: %w", err) + } + } } } return nil @@ -167,6 +177,8 @@ func (u *User) String() string { builder.WriteString(fmt.Sprintf("%v", u.Floats)) builder.WriteString(", strings=") builder.WriteString(fmt.Sprintf("%v", u.Strings)) + builder.WriteString(", addr=") + builder.WriteString(fmt.Sprintf("%v", u.Addr)) builder.WriteByte(')') return builder.String() } diff --git a/entc/integration/json/ent/user/user.go b/entc/integration/json/ent/user/user.go index c3a97bcbdb..29f57d6e0b 100644 --- a/entc/integration/json/ent/user/user.go +++ b/entc/integration/json/ent/user/user.go @@ -29,6 +29,8 @@ const ( FieldFloats = "floats" // FieldStrings holds the string denoting the strings field in the database. FieldStrings = "strings" + // FieldAddr holds the string denoting the addr field in the database. + FieldAddr = "addr" // Table holds the table name of the user in the database. Table = "users" ) @@ -43,6 +45,7 @@ var Columns = []string{ FieldInts, FieldFloats, FieldStrings, + FieldAddr, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/entc/integration/json/ent/user/where.go b/entc/integration/json/ent/user/where.go index 8c723452ac..c89b4d5d08 100644 --- a/entc/integration/json/ent/user/where.go +++ b/entc/integration/json/ent/user/where.go @@ -178,6 +178,20 @@ func StringsNotNil() predicate.User { }) } +// AddrIsNil applies the IsNil predicate on the "addr" field. +func AddrIsNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.IsNull(s.C(FieldAddr))) + }) +} + +// AddrNotNil applies the NotNil predicate on the "addr" field. +func AddrNotNil() predicate.User { + return predicate.User(func(s *sql.Selector) { + s.Where(sql.NotNull(s.C(FieldAddr))) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.User) predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/entc/integration/json/ent/user_create.go b/entc/integration/json/ent/user_create.go index 27dca8a11d..bc1239d09c 100644 --- a/entc/integration/json/ent/user_create.go +++ b/entc/integration/json/ent/user_create.go @@ -69,6 +69,20 @@ func (uc *UserCreate) SetStrings(s []string) *UserCreate { return uc } +// SetAddr sets the "addr" field. +func (uc *UserCreate) SetAddr(s schema.Addr) *UserCreate { + uc.mutation.SetAddr(s) + return uc +} + +// SetNillableAddr sets the "addr" field if the given value is not nil. +func (uc *UserCreate) SetNillableAddr(s *schema.Addr) *UserCreate { + if s != nil { + uc.SetAddr(*s) + } + return uc +} + // Mutation returns the UserMutation object of the builder. func (uc *UserCreate) Mutation() *UserMutation { return uc.mutation @@ -238,6 +252,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { }) _node.Strings = value } + if value, ok := uc.mutation.Addr(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeJSON, + Value: value, + Column: user.FieldAddr, + }) + _node.Addr = value + } return _node, _spec } diff --git a/entc/integration/json/ent/user_update.go b/entc/integration/json/ent/user_update.go index 569eed7168..8a118c4ae1 100644 --- a/entc/integration/json/ent/user_update.go +++ b/entc/integration/json/ent/user_update.go @@ -113,6 +113,26 @@ func (uu *UserUpdate) ClearStrings() *UserUpdate { return uu } +// SetAddr sets the "addr" field. +func (uu *UserUpdate) SetAddr(s schema.Addr) *UserUpdate { + uu.mutation.SetAddr(s) + return uu +} + +// SetNillableAddr sets the "addr" field if the given value is not nil. +func (uu *UserUpdate) SetNillableAddr(s *schema.Addr) *UserUpdate { + if s != nil { + uu.SetAddr(*s) + } + return uu +} + +// ClearAddr clears the value of the "addr" field. +func (uu *UserUpdate) ClearAddr() *UserUpdate { + uu.mutation.ClearAddr() + return uu +} + // Mutation returns the UserMutation object of the builder. func (uu *UserUpdate) Mutation() *UserMutation { return uu.mutation @@ -275,6 +295,19 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: user.FieldStrings, }) } + if value, ok := uu.mutation.Addr(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeJSON, + Value: value, + Column: user.FieldAddr, + }) + } + if uu.mutation.AddrCleared() { + _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ + Type: field.TypeJSON, + Column: user.FieldAddr, + }) + } if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -372,6 +405,26 @@ func (uuo *UserUpdateOne) ClearStrings() *UserUpdateOne { return uuo } +// SetAddr sets the "addr" field. +func (uuo *UserUpdateOne) SetAddr(s schema.Addr) *UserUpdateOne { + uuo.mutation.SetAddr(s) + return uuo +} + +// SetNillableAddr sets the "addr" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableAddr(s *schema.Addr) *UserUpdateOne { + if s != nil { + uuo.SetAddr(*s) + } + return uuo +} + +// ClearAddr clears the value of the "addr" field. +func (uuo *UserUpdateOne) ClearAddr() *UserUpdateOne { + uuo.mutation.ClearAddr() + return uuo +} + // Mutation returns the UserMutation object of the builder. func (uuo *UserUpdateOne) Mutation() *UserMutation { return uuo.mutation @@ -558,6 +611,19 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) Column: user.FieldStrings, }) } + if value, ok := uuo.mutation.Addr(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeJSON, + Value: value, + Column: user.FieldAddr, + }) + } + if uuo.mutation.AddrCleared() { + _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ + Type: field.TypeJSON, + Column: user.FieldAddr, + }) + } _node = &User{config: uuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/entc/integration/json/json_test.go b/entc/integration/json/json_test.go index d97f257071..12684fe99f 100644 --- a/entc/integration/json/json_test.go +++ b/entc/integration/json/json_test.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" "net/url" "testing" @@ -46,6 +47,7 @@ func TestMySQL(t *testing.T) { Ints(t, client) Floats(t, client) Strings(t, client) + NetAddr(t, client) RawMessage(t, client) // Skip predicates test for MySQL old versions. if version != "56" { @@ -79,6 +81,7 @@ func TestMaria(t *testing.T) { Ints(t, client) Floats(t, client) Strings(t, client) + NetAddr(t, client) RawMessage(t, client) Predicates(t, client) }) @@ -108,6 +111,7 @@ func TestPostgres(t *testing.T) { Ints(t, client) Floats(t, client) Strings(t, client) + NetAddr(t, client) RawMessage(t, client) Predicates(t, client) }) @@ -126,6 +130,7 @@ func TestSQLite(t *testing.T) { Ints(t, client) Floats(t, client) Strings(t, client) + NetAddr(t, client) RawMessage(t, client) Predicates(t, client) } @@ -187,6 +192,15 @@ func RawMessage(t *testing.T, client *ent.Client) { require.Equal(t, raw, client.User.GetX(ctx, usr.ID).Raw) } +func NetAddr(t *testing.T, client *ent.Client) { + ctx := context.Background() + ip := net.ParseIP("127.0.0.1") + usr := client.User.Create().SetAddr(schema.Addr{Addr: &net.TCPAddr{IP: ip, Port: 80}}).SaveX(ctx) + require.Equal(t, "127.0.0.1:80", client.User.GetX(ctx, usr.ID).Addr.String()) + usr.Update().SetAddr(schema.Addr{Addr: &net.UDPAddr{IP: ip, Port: 1812}}).ExecX(ctx) + require.Equal(t, "127.0.0.1:1812", client.User.GetX(ctx, usr.ID).Addr.String()) +} + func Dirs(t *testing.T, client *ent.Client) { ctx := context.Background() dirs := []http.Dir{"dev", "usr"} diff --git a/schema/field/field.go b/schema/field/field.go index c895b304e9..baad1043f8 100644 --- a/schema/field/field.go +++ b/schema/field/field.go @@ -71,15 +71,19 @@ func Time(name string) *timeBuilder { // Optional() // func JSON(name string, typ interface{}) *jsonBuilder { - t := reflect.TypeOf(typ) b := &jsonBuilder{&Descriptor{ Name: name, Info: &TypeInfo{ - Type: TypeJSON, - Ident: t.String(), - PkgPath: t.PkgPath(), + Type: TypeJSON, }, }} + t := reflect.TypeOf(typ) + if t == nil { + b.desc.Err = errors.New("expect a Go value as JSON type, but got nil") + return b + } + b.desc.Info.Ident = t.String() + b.desc.Info.PkgPath = t.PkgPath() b.desc.goType(typ, t) switch t.Kind() { case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map: diff --git a/schema/field/field_test.go b/schema/field/field_test.go index 88bd8b0d1a..1c85db6d6f 100644 --- a/schema/field/field_test.go +++ b/schema/field/field_test.go @@ -508,6 +508,8 @@ func TestJSON(t *testing.T) { assert.Equal(t, "net/url", fd.Info.PkgPath) fd = field.JSON("values", map[string]*url.Values{}).Descriptor() assert.Equal(t, "net/url", fd.Info.PkgPath) + fd = field.JSON("addr", net.Addr(nil)).Descriptor() + assert.EqualError(t, fd.Err, "expect a Go value as JSON type, but got nil") } func TestField_Tag(t *testing.T) {