From 1eb749407b348e62108e7cfd9c9cb8331004636e Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 26 Sep 2022 21:05:10 +0300 Subject: [PATCH] dialect/sql/schema: use serial underlying types for fks --- dialect/sql/schema/postgres.go | 3 + entc/integration/migrate/entv2/client.go | 16 ++ entc/integration/migrate/entv2/group.go | 26 +++ entc/integration/migrate/entv2/group/group.go | 11 ++ entc/integration/migrate/entv2/group/where.go | 29 +++ .../integration/migrate/entv2/group_create.go | 53 ++++- entc/integration/migrate/entv2/group_query.go | 84 +++++++- .../integration/migrate/entv2/group_update.go | 181 ++++++++++++++++++ .../migrate/entv2/migrate/schema.go | 12 +- entc/integration/migrate/entv2/mutation.go | 106 +++++++++- entc/integration/migrate/entv2/schema/user.go | 23 ++- entc/integration/migrate/entv2/user.go | 12 +- entc/integration/migrate/entv2/user/user.go | 11 ++ entc/integration/migrate/entv2/user_query.go | 5 + entc/integration/migrate/migrate_test.go | 34 +++- 15 files changed, 586 insertions(+), 20 deletions(-) diff --git a/dialect/sql/schema/postgres.go b/dialect/sql/schema/postgres.go index ea68272b76..140b3f5a3b 100644 --- a/dialect/sql/schema/postgres.go +++ b/dialect/sql/schema/postgres.go @@ -690,6 +690,9 @@ func (d *Postgres) atTypeC(c1 *Column, c2 *schema.Column) error { return err } c2.Type.Type = t + if s, ok := t.(*postgres.SerialType); c1.foreign != nil && ok { + c2.Type.Type = s.IntegerType() + } return nil } var t schema.Type diff --git a/entc/integration/migrate/entv2/client.go b/entc/integration/migrate/entv2/client.go index 7deeac9e86..0f81bf27c6 100644 --- a/entc/integration/migrate/entv2/client.go +++ b/entc/integration/migrate/entv2/client.go @@ -540,6 +540,22 @@ func (c *GroupClient) GetX(ctx context.Context, id int) *Group { return obj } +// QueryAdmins queries the admins edge of a Group. +func (c *GroupClient) QueryAdmins(gr *Group) *UserQuery { + query := &UserQuery{config: c.config} + query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + id := gr.ID + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.AdminsTable, group.AdminsColumn), + ) + fromV = sqlgraph.Neighbors(gr.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *GroupClient) Hooks() []Hook { return c.hooks.Group diff --git a/entc/integration/migrate/entv2/group.go b/entc/integration/migrate/entv2/group.go index 5517b1036a..686a290da1 100644 --- a/entc/integration/migrate/entv2/group.go +++ b/entc/integration/migrate/entv2/group.go @@ -19,6 +19,27 @@ type Group struct { config // ID of the ent. ID int `json:"id,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the GroupQuery when eager-loading is set. + Edges GroupEdges `json:"edges"` +} + +// GroupEdges holds the relations/edges for other nodes in the graph. +type GroupEdges struct { + // Admins holds the value of the admins edge. + Admins []*User `json:"admins,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// AdminsOrErr returns the Admins value or an error if the edge +// was not loaded in eager-loading. +func (e GroupEdges) AdminsOrErr() ([]*User, error) { + if e.loadedTypes[0] { + return e.Admins, nil + } + return nil, &NotLoadedError{edge: "admins"} } // scanValues returns the types for scanning values from sql.Rows. @@ -54,6 +75,11 @@ func (gr *Group) assignValues(columns []string, values []any) error { return nil } +// QueryAdmins queries the "admins" edge of the Group entity. +func (gr *Group) QueryAdmins() *UserQuery { + return (&GroupClient{config: gr.config}).QueryAdmins(gr) +} + // Update returns a builder for updating this Group. // Note that you need to call Group.Unwrap() before calling this method if this Group // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/entc/integration/migrate/entv2/group/group.go b/entc/integration/migrate/entv2/group/group.go index 8610210047..ef6b5783ca 100644 --- a/entc/integration/migrate/entv2/group/group.go +++ b/entc/integration/migrate/entv2/group/group.go @@ -11,8 +11,19 @@ const ( Label = "group" // FieldID holds the string denoting the id field in the database. FieldID = "id" + // EdgeAdmins holds the string denoting the admins edge name in mutations. + EdgeAdmins = "admins" + // UserFieldID holds the string denoting the ID field of the User. + UserFieldID = "oid" // Table holds the table name of the group in the database. Table = "groups" + // AdminsTable is the table that holds the admins relation/edge. + AdminsTable = "users" + // AdminsInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + AdminsInverseTable = "users" + // AdminsColumn is the table column denoting the admins relation/edge. + AdminsColumn = "group_admins" ) // Columns holds all SQL columns for group fields. diff --git a/entc/integration/migrate/entv2/group/where.go b/entc/integration/migrate/entv2/group/where.go index c52847284d..cdd2c31838 100644 --- a/entc/integration/migrate/entv2/group/where.go +++ b/entc/integration/migrate/entv2/group/where.go @@ -8,6 +8,7 @@ package group import ( "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/entv2/predicate" ) @@ -82,6 +83,34 @@ func IDLTE(id int) predicate.Group { }) } +// HasAdmins applies the HasEdge predicate on the "admins" edge. +func HasAdmins() predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdminsTable, UserFieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdminsTable, AdminsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdminsWith applies the HasEdge predicate on the "admins" edge with a given conditions (other predicates). +func HasAdminsWith(preds ...predicate.User) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdminsInverseTable, UserFieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdminsTable, AdminsColumn), + ) + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Group) predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/entc/integration/migrate/entv2/group_create.go b/entc/integration/migrate/entv2/group_create.go index 820b14cf2c..b575d2e064 100644 --- a/entc/integration/migrate/entv2/group_create.go +++ b/entc/integration/migrate/entv2/group_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/entv2/group" + "entgo.io/ent/entc/integration/migrate/entv2/user" "entgo.io/ent/schema/field" ) @@ -22,6 +23,27 @@ type GroupCreate struct { hooks []Hook } +// SetID sets the "id" field. +func (gc *GroupCreate) SetID(i int) *GroupCreate { + gc.mutation.SetID(i) + return gc +} + +// AddAdminIDs adds the "admins" edge to the User entity by IDs. +func (gc *GroupCreate) AddAdminIDs(ids ...int) *GroupCreate { + gc.mutation.AddAdminIDs(ids...) + return gc +} + +// AddAdmins adds the "admins" edges to the User entity. +func (gc *GroupCreate) AddAdmins(u ...*User) *GroupCreate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return gc.AddAdminIDs(ids...) +} + // Mutation returns the GroupMutation object of the builder. func (gc *GroupCreate) Mutation() *GroupMutation { return gc.mutation @@ -109,8 +131,10 @@ func (gc *GroupCreate) sqlSave(ctx context.Context) (*Group, error) { } return nil, err } - id := _spec.ID.Value.(int64) - _node.ID = int(id) + if _spec.ID.Value != _node.ID { + id := _spec.ID.Value.(int64) + _node.ID = int(id) + } return _node, nil } @@ -125,6 +149,29 @@ func (gc *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { }, } ) + if id, ok := gc.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = id + } + if nodes := gc.mutation.AdminsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -168,7 +215,7 @@ func (gcb *GroupCreateBulk) Save(ctx context.Context) ([]*Group, error) { return nil, err } mutation.id = &nodes[i].ID - if specs[i].ID.Value != nil { + if specs[i].ID.Value != nil && nodes[i].ID == 0 { id := specs[i].ID.Value.(int64) nodes[i].ID = int(id) } diff --git a/entc/integration/migrate/entv2/group_query.go b/entc/integration/migrate/entv2/group_query.go index 82db0a95cf..3c1c1e169b 100644 --- a/entc/integration/migrate/entv2/group_query.go +++ b/entc/integration/migrate/entv2/group_query.go @@ -8,6 +8,7 @@ package entv2 import ( "context" + "database/sql/driver" "fmt" "math" @@ -15,6 +16,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/entv2/group" "entgo.io/ent/entc/integration/migrate/entv2/predicate" + "entgo.io/ent/entc/integration/migrate/entv2/user" "entgo.io/ent/schema/field" ) @@ -27,6 +29,7 @@ type GroupQuery struct { order []OrderFunc fields []string predicates []predicate.Group + withAdmins *UserQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -63,6 +66,28 @@ func (gq *GroupQuery) Order(o ...OrderFunc) *GroupQuery { return gq } +// QueryAdmins chains the current query on the "admins" edge. +func (gq *GroupQuery) QueryAdmins() *UserQuery { + query := &UserQuery{config: gq.config} + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := gq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := gq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.AdminsTable, group.AdminsColumn), + ) + fromU = sqlgraph.SetNeighbors(gq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first Group entity from the query. // Returns a *NotFoundError when no Group was found. func (gq *GroupQuery) First(ctx context.Context) (*Group, error) { @@ -244,6 +269,7 @@ func (gq *GroupQuery) Clone() *GroupQuery { offset: gq.offset, order: append([]OrderFunc{}, gq.order...), predicates: append([]predicate.Group{}, gq.predicates...), + withAdmins: gq.withAdmins.Clone(), // clone intermediate query. sql: gq.sql.Clone(), path: gq.path, @@ -251,6 +277,17 @@ func (gq *GroupQuery) Clone() *GroupQuery { } } +// WithAdmins tells the query-builder to eager-load the nodes that are connected to +// the "admins" edge. The optional arguments are used to configure the query builder of the edge. +func (gq *GroupQuery) WithAdmins(opts ...func(*UserQuery)) *GroupQuery { + query := &UserQuery{config: gq.config} + for _, opt := range opts { + opt(query) + } + gq.withAdmins = query + return gq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. func (gq *GroupQuery) GroupBy(field string, fields ...string) *GroupGroupBy { @@ -295,8 +332,11 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( - nodes = []*Group{} - _spec = gq.querySpec() + nodes = []*Group{} + _spec = gq.querySpec() + loadedTypes = [1]bool{ + gq.withAdmins != nil, + } ) _spec.ScanValues = func(columns []string) ([]any, error) { return (*Group).scanValues(nil, columns) @@ -304,6 +344,7 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, _spec.Assign = func(columns []string, values []any) error { node := &Group{config: gq.config} nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } for i := range hooks { @@ -315,9 +356,48 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } + if query := gq.withAdmins; query != nil { + if err := gq.loadAdmins(ctx, query, nodes, + func(n *Group) { n.Edges.Admins = []*User{} }, + func(n *Group, e *User) { n.Edges.Admins = append(n.Edges.Admins, e) }); err != nil { + return nil, err + } + } return nodes, nil } +func (gq *GroupQuery) loadAdmins(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(group.AdminsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.group_admins + if fk == nil { + return fmt.Errorf(`foreign-key "group_admins" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_admins" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { _spec := gq.querySpec() _spec.Node.Columns = gq.fields diff --git a/entc/integration/migrate/entv2/group_update.go b/entc/integration/migrate/entv2/group_update.go index 6aeda94a0f..a226392904 100644 --- a/entc/integration/migrate/entv2/group_update.go +++ b/entc/integration/migrate/entv2/group_update.go @@ -15,6 +15,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/entv2/group" "entgo.io/ent/entc/integration/migrate/entv2/predicate" + "entgo.io/ent/entc/integration/migrate/entv2/user" "entgo.io/ent/schema/field" ) @@ -31,11 +32,47 @@ func (gu *GroupUpdate) Where(ps ...predicate.Group) *GroupUpdate { return gu } +// AddAdminIDs adds the "admins" edge to the User entity by IDs. +func (gu *GroupUpdate) AddAdminIDs(ids ...int) *GroupUpdate { + gu.mutation.AddAdminIDs(ids...) + return gu +} + +// AddAdmins adds the "admins" edges to the User entity. +func (gu *GroupUpdate) AddAdmins(u ...*User) *GroupUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return gu.AddAdminIDs(ids...) +} + // Mutation returns the GroupMutation object of the builder. func (gu *GroupUpdate) Mutation() *GroupMutation { return gu.mutation } +// ClearAdmins clears all "admins" edges to the User entity. +func (gu *GroupUpdate) ClearAdmins() *GroupUpdate { + gu.mutation.ClearAdmins() + return gu +} + +// RemoveAdminIDs removes the "admins" edge to User entities by IDs. +func (gu *GroupUpdate) RemoveAdminIDs(ids ...int) *GroupUpdate { + gu.mutation.RemoveAdminIDs(ids...) + return gu +} + +// RemoveAdmins removes "admins" edges to User entities. +func (gu *GroupUpdate) RemoveAdmins(u ...*User) *GroupUpdate { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return gu.RemoveAdminIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (gu *GroupUpdate) Save(ctx context.Context) (int, error) { var ( @@ -108,6 +145,60 @@ func (gu *GroupUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } + if gu.mutation.AdminsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := gu.mutation.RemovedAdminsIDs(); len(nodes) > 0 && !gu.mutation.AdminsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := gu.mutation.AdminsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if n, err = sqlgraph.UpdateNodes(ctx, gu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{group.Label} @@ -127,11 +218,47 @@ type GroupUpdateOne struct { mutation *GroupMutation } +// AddAdminIDs adds the "admins" edge to the User entity by IDs. +func (guo *GroupUpdateOne) AddAdminIDs(ids ...int) *GroupUpdateOne { + guo.mutation.AddAdminIDs(ids...) + return guo +} + +// AddAdmins adds the "admins" edges to the User entity. +func (guo *GroupUpdateOne) AddAdmins(u ...*User) *GroupUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return guo.AddAdminIDs(ids...) +} + // Mutation returns the GroupMutation object of the builder. func (guo *GroupUpdateOne) Mutation() *GroupMutation { return guo.mutation } +// ClearAdmins clears all "admins" edges to the User entity. +func (guo *GroupUpdateOne) ClearAdmins() *GroupUpdateOne { + guo.mutation.ClearAdmins() + return guo +} + +// RemoveAdminIDs removes the "admins" edge to User entities by IDs. +func (guo *GroupUpdateOne) RemoveAdminIDs(ids ...int) *GroupUpdateOne { + guo.mutation.RemoveAdminIDs(ids...) + return guo +} + +// RemoveAdmins removes "admins" edges to User entities. +func (guo *GroupUpdateOne) RemoveAdmins(u ...*User) *GroupUpdateOne { + ids := make([]int, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return guo.RemoveAdminIDs(ids...) +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (guo *GroupUpdateOne) Select(field string, fields ...string) *GroupUpdateOne { @@ -234,6 +361,60 @@ func (guo *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error } } } + if guo.mutation.AdminsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := guo.mutation.RemovedAdminsIDs(); len(nodes) > 0 && !guo.mutation.AdminsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := guo.mutation.AdminsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.AdminsTable, + Columns: []string{group.AdminsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: &sqlgraph.FieldSpec{ + Type: field.TypeInt, + Column: user.FieldID, + }, + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Group{config: guo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/entc/integration/migrate/entv2/migrate/schema.go b/entc/integration/migrate/entv2/migrate/schema.go index 6b123526fb..2affbdd345 100644 --- a/entc/integration/migrate/entv2/migrate/schema.go +++ b/entc/integration/migrate/entv2/migrate/schema.go @@ -67,7 +67,7 @@ var ( } // GroupsColumns holds the columns for the "groups" table. GroupsColumns = []*schema.Column{ - {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "id", Type: field.TypeInt, Increment: true, SchemaType: map[string]string{"postgres": "serial"}}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -149,12 +149,21 @@ var ( {Name: "workplace", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime, Default: "CURRENT_TIMESTAMP"}, {Name: "drop_optional", Type: field.TypeString}, + {Name: "group_admins", Type: field.TypeInt, Nullable: true, SchemaType: map[string]string{"postgres": "serial"}}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ Name: "users", Columns: UsersColumns, PrimaryKey: []*schema.Column{UsersColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "users_groups_admins", + Columns: []*schema.Column{UsersColumns[19]}, + RefColumns: []*schema.Column{GroupsColumns[0]}, + OnDelete: schema.SetNull, + }, + }, Indexes: []*schema.Index{ { Name: "user_description", @@ -257,6 +266,7 @@ func init() { "boring_check": "source_uri <> 'entgo.io'", } PetsTable.ForeignKeys[0].RefTable = UsersTable + UsersTable.ForeignKeys[0].RefTable = GroupsTable FriendsTable.ForeignKeys[0].RefTable = UsersTable FriendsTable.ForeignKeys[1].RefTable = UsersTable } diff --git a/entc/integration/migrate/entv2/mutation.go b/entc/integration/migrate/entv2/mutation.go index c4eb89a9a5..779515d5e1 100644 --- a/entc/integration/migrate/entv2/mutation.go +++ b/entc/integration/migrate/entv2/mutation.go @@ -16,6 +16,7 @@ import ( "entgo.io/ent/entc/integration/migrate/entv2/car" "entgo.io/ent/entc/integration/migrate/entv2/conversion" "entgo.io/ent/entc/integration/migrate/entv2/customtype" + "entgo.io/ent/entc/integration/migrate/entv2/group" "entgo.io/ent/entc/integration/migrate/entv2/media" "entgo.io/ent/entc/integration/migrate/entv2/pet" "entgo.io/ent/entc/integration/migrate/entv2/predicate" @@ -1845,6 +1846,9 @@ type GroupMutation struct { typ string id *int clearedFields map[string]struct{} + admins map[int]struct{} + removedadmins map[int]struct{} + clearedadmins bool done bool oldValue func(context.Context) (*Group, error) predicates []predicate.Group @@ -1920,6 +1924,12 @@ func (m GroupMutation) Tx() (*Tx, error) { return tx, nil } +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of Group entities. +func (m *GroupMutation) SetID(id int) { + m.id = &id +} + // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. func (m *GroupMutation) ID() (id int, exists bool) { @@ -1948,6 +1958,60 @@ func (m *GroupMutation) IDs(ctx context.Context) ([]int, error) { } } +// AddAdminIDs adds the "admins" edge to the User entity by ids. +func (m *GroupMutation) AddAdminIDs(ids ...int) { + if m.admins == nil { + m.admins = make(map[int]struct{}) + } + for i := range ids { + m.admins[ids[i]] = struct{}{} + } +} + +// ClearAdmins clears the "admins" edge to the User entity. +func (m *GroupMutation) ClearAdmins() { + m.clearedadmins = true +} + +// AdminsCleared reports if the "admins" edge to the User entity was cleared. +func (m *GroupMutation) AdminsCleared() bool { + return m.clearedadmins +} + +// RemoveAdminIDs removes the "admins" edge to the User entity by IDs. +func (m *GroupMutation) RemoveAdminIDs(ids ...int) { + if m.removedadmins == nil { + m.removedadmins = make(map[int]struct{}) + } + for i := range ids { + delete(m.admins, ids[i]) + m.removedadmins[ids[i]] = struct{}{} + } +} + +// RemovedAdmins returns the removed IDs of the "admins" edge to the User entity. +func (m *GroupMutation) RemovedAdminsIDs() (ids []int) { + for id := range m.removedadmins { + ids = append(ids, id) + } + return +} + +// AdminsIDs returns the "admins" edge IDs in the mutation. +func (m *GroupMutation) AdminsIDs() (ids []int) { + for id := range m.admins { + ids = append(ids, id) + } + return +} + +// ResetAdmins resets all changes to the "admins" edge. +func (m *GroupMutation) ResetAdmins() { + m.admins = nil + m.clearedadmins = false + m.removedadmins = nil +} + // Where appends a list predicates to the GroupMutation builder. func (m *GroupMutation) Where(ps ...predicate.Group) { m.predicates = append(m.predicates, ps...) @@ -2041,49 +2105,85 @@ func (m *GroupMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *GroupMutation) AddedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.admins != nil { + edges = append(edges, group.EdgeAdmins) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. func (m *GroupMutation) AddedIDs(name string) []ent.Value { + switch name { + case group.EdgeAdmins: + ids := make([]ent.Value, 0, len(m.admins)) + for id := range m.admins { + ids = append(ids, id) + } + return ids + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.removedadmins != nil { + edges = append(edges, group.EdgeAdmins) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *GroupMutation) RemovedIDs(name string) []ent.Value { + switch name { + case group.EdgeAdmins: + ids := make([]ent.Value, 0, len(m.removedadmins)) + for id := range m.removedadmins { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.clearedadmins { + edges = append(edges, group.EdgeAdmins) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. func (m *GroupMutation) EdgeCleared(name string) bool { + switch name { + case group.EdgeAdmins: + return m.clearedadmins + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. func (m *GroupMutation) ClearEdge(name string) error { + switch name { + } return fmt.Errorf("unknown Group unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. func (m *GroupMutation) ResetEdge(name string) error { + switch name { + case group.EdgeAdmins: + m.ResetAdmins() + return nil + } return fmt.Errorf("unknown Group edge %s", name) } diff --git a/entc/integration/migrate/entv2/schema/user.go b/entc/integration/migrate/entv2/schema/user.go index 2583f4f9d9..d8ae382795 100644 --- a/entc/integration/migrate/entv2/schema/user.go +++ b/entc/integration/migrate/entv2/schema/user.go @@ -7,17 +7,17 @@ package schema import ( "time" - "github.com/google/uuid" - - "entgo.io/ent/schema" - "entgo.io/ent" "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" "entgo.io/ent/schema/mixin" + + "ariga.io/atlas/sql/postgres" + "github.com/google/uuid" ) type Mixin struct { @@ -197,6 +197,21 @@ func (Car) Edges() []ent.Edge { // Group schema. type Group struct{ ent.Schema } +func (Group) Fields() []ent.Field { + return []ent.Field{ + field.Int("id"). + SchemaType(map[string]string{ + dialect.Postgres: postgres.TypeSerial, + }), + } +} + +func (Group) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("admins", User.Type), + } +} + // Pet schema. type Pet struct { ent.Schema diff --git a/entc/integration/migrate/entv2/user.go b/entc/integration/migrate/entv2/user.go index 83d4b71d96..743df7fe38 100644 --- a/entc/integration/migrate/entv2/user.go +++ b/entc/integration/migrate/entv2/user.go @@ -59,7 +59,8 @@ type User struct { DropOptional string `json:"drop_optional,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. - Edges UserEdges `json:"edges"` + Edges UserEdges `json:"edges"` + group_admins *int } // UserEdges holds the relations/edges for other nodes in the graph. @@ -121,6 +122,8 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullString) case user.FieldCreatedAt: values[i] = new(sql.NullTime) + case user.ForeignKeys[0]: // group_admins + values[i] = new(sql.NullInt64) default: return nil, fmt.Errorf("unexpected column %q for type User", columns[i]) } @@ -250,6 +253,13 @@ func (u *User) assignValues(columns []string, values []any) error { } else if value.Valid { u.DropOptional = value.String } + case user.ForeignKeys[0]: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for edge-field group_admins", value) + } else if value.Valid { + u.group_admins = new(int) + *u.group_admins = int(value.Int64) + } } } return nil diff --git a/entc/integration/migrate/entv2/user/user.go b/entc/integration/migrate/entv2/user/user.go index 7dd24e9927..1db0c9873c 100644 --- a/entc/integration/migrate/entv2/user/user.go +++ b/entc/integration/migrate/entv2/user/user.go @@ -105,6 +105,12 @@ var Columns = []string{ FieldDropOptional, } +// ForeignKeys holds the SQL foreign-keys that are owned by the "users" +// table and are not defined as standalone fields in the schema. +var ForeignKeys = []string{ + "group_admins", +} + var ( // FriendsPrimaryKey and FriendsColumn2 are the table columns denoting the // primary key for the friends relation (M2M). @@ -118,6 +124,11 @@ func ValidColumn(column string) bool { return true } } + for i := range ForeignKeys { + if column == ForeignKeys[i] { + return true + } + } return false } diff --git a/entc/integration/migrate/entv2/user_query.go b/entc/integration/migrate/entv2/user_query.go index b6bfbc9815..2313193725 100644 --- a/entc/integration/migrate/entv2/user_query.go +++ b/entc/integration/migrate/entv2/user_query.go @@ -33,6 +33,7 @@ type UserQuery struct { withCar *CarQuery withPets *PetQuery withFriends *UserQuery + withFKs bool // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -426,6 +427,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} + withFKs = uq.withFKs _spec = uq.querySpec() loadedTypes = [3]bool{ uq.withCar != nil, @@ -433,6 +435,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e uq.withFriends != nil, } ) + if withFKs { + _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) + } _spec.ScanValues = func(columns []string) ([]any, error) { return (*User).scanValues(nil, columns) } diff --git a/entc/integration/migrate/migrate_test.go b/entc/integration/migrate/migrate_test.go index 9d60315984..d939366ad2 100644 --- a/entc/integration/migrate/migrate_test.go +++ b/entc/integration/migrate/migrate_test.go @@ -19,9 +19,6 @@ import ( "testing" "text/template" - "ariga.io/atlas/sql/migrate" - atlas "ariga.io/atlas/sql/schema" - "ariga.io/atlas/sql/sqltool" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/schema" @@ -31,11 +28,17 @@ import ( "entgo.io/ent/entc/integration/migrate/entv2" "entgo.io/ent/entc/integration/migrate/entv2/conversion" "entgo.io/ent/entc/integration/migrate/entv2/customtype" + "entgo.io/ent/entc/integration/migrate/entv2/group" migratev2 "entgo.io/ent/entc/integration/migrate/entv2/migrate" "entgo.io/ent/entc/integration/migrate/entv2/predicate" "entgo.io/ent/entc/integration/migrate/entv2/user" "entgo.io/ent/entc/integration/migrate/versioned" vmigrate "entgo.io/ent/entc/integration/migrate/versioned/migrate" + + "ariga.io/atlas/sql/migrate" + "ariga.io/atlas/sql/postgres" + atlas "ariga.io/atlas/sql/schema" + "ariga.io/atlas/sql/sqltool" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -106,7 +109,26 @@ func TestPostgres(t *testing.T) { clientv1 := entv1.NewClient(entv1.Driver(drv)) clientv2 := entv2.NewClient(entv2.Driver(drv)) - V1ToV2(t, drv.Dialect(), clientv1, clientv2) + V1ToV2( + t, drv.Dialect(), clientv1, clientv2, + // A diff hook to ensure foreign-keys that point to + // serial columns are configured to integer types. + func(next schema.Differ) schema.Differ { + return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) { + groups, ok := desired.Table(group.Table) + require.True(t, ok) + id, ok := groups.Column(group.FieldID) + require.True(t, ok) + require.IsType(t, &postgres.SerialType{}, id.Type.Type) + users, ok := desired.Table(user.Table) + require.True(t, ok) + fk, ok := users.Column(group.AdminsColumn) + require.True(t, ok) + require.IsType(t, &atlas.IntegerType{}, fk.Type.Type) + return next.Diff(current, desired) + }) + }, + ) CheckConstraint(t, clientv2) TimePrecision(t, drv, "SELECT datetime_precision FROM information_schema.columns WHERE table_name = $1 AND column_name = $2") PartialIndexes(t, drv, "select indexdef from pg_indexes where indexname=$1", "CREATE INDEX user_phone ON public.users USING btree (phone) WHERE active") @@ -312,7 +334,7 @@ func Versioned(t *testing.T, drv sql.ExecQuerier, devURL string, client *version require.Equal(t, string(f1), string(f2)) } -func V1ToV2(t *testing.T, dialect string, clientv1 *entv1.Client, clientv2 *entv2.Client) { +func V1ToV2(t *testing.T, dialect string, clientv1 *entv1.Client, clientv2 *entv2.Client, hooks ...schema.DiffHook) { ctx := context.Background() // Run migration and execute queries on v1. @@ -328,7 +350,7 @@ func V1ToV2(t *testing.T, dialect string, clientv1 *entv1.Client, clientv2 *entv clientv1.Conversion.DeleteOne(c1).ExecX(ctx) // Run migration and execute queries on v2. - require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithDiffHook(renameTokenColumn), schema.WithApplyHook(fillNulls(dialect)))) + require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true), schema.WithDiffHook(append(hooks, renameTokenColumn)...), schema.WithApplyHook(fillNulls(dialect)))) require.NoError(t, clientv2.Schema.Create(ctx, migratev2.WithGlobalUniqueID(true), migratev2.WithDropIndex(true), migratev2.WithDropColumn(true)), "should not create additional resources on multiple runs") SanityV2(t, dialect, clientv2) clientv2.Conversion.CreateBulk(clientv2.Conversion.Create(), clientv2.Conversion.Create(), clientv2.Conversion.Create()).ExecX(ctx)