From 5d93290fb458fbc2b23ca263fd53a3de846e6498 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 20 Mar 2022 11:57:24 +0200 Subject: [PATCH] entc/gen: use join for m2m relationship --- entc/gen/template/dialect/sql/globals.tmpl | 6 +- entc/gen/template/dialect/sql/query.tmpl | 111 +++--- .../cascadelete/ent/comment_query.go | 15 +- entc/integration/cascadelete/ent/ent.go | 4 + .../integration/cascadelete/ent/post_query.go | 15 +- .../integration/cascadelete/ent/user_query.go | 15 +- entc/integration/config/ent/ent.go | 4 + entc/integration/config/ent/user_query.go | 15 +- .../integration/customid/ent/account_query.go | 15 +- entc/integration/customid/ent/blob_query.go | 99 +++-- entc/integration/customid/ent/car_query.go | 15 +- entc/integration/customid/ent/device_query.go | 15 +- entc/integration/customid/ent/doc_query.go | 15 +- entc/integration/customid/ent/ent.go | 4 + entc/integration/customid/ent/group_query.go | 99 +++-- .../integration/customid/ent/mixinid_query.go | 15 +- entc/integration/customid/ent/note_query.go | 15 +- entc/integration/customid/ent/other_query.go | 15 +- entc/integration/customid/ent/pet_query.go | 99 +++-- .../integration/customid/ent/session_query.go | 15 +- entc/integration/customid/ent/token_query.go | 15 +- entc/integration/customid/ent/user_query.go | 99 +++-- entc/integration/edgefield/ent/car_query.go | 15 +- entc/integration/edgefield/ent/card_query.go | 15 +- entc/integration/edgefield/ent/ent.go | 4 + entc/integration/edgefield/ent/info_query.go | 15 +- .../edgefield/ent/metadata_query.go | 15 +- entc/integration/edgefield/ent/node_query.go | 15 +- entc/integration/edgefield/ent/pet_query.go | 15 +- entc/integration/edgefield/ent/post_query.go | 15 +- .../integration/edgefield/ent/rental_query.go | 15 +- entc/integration/edgefield/ent/user_query.go | 15 +- entc/integration/ent/card_query.go | 99 +++-- entc/integration/ent/comment_query.go | 15 +- entc/integration/ent/ent.go | 4 + entc/integration/ent/fieldtype_query.go | 15 +- entc/integration/ent/file_query.go | 15 +- entc/integration/ent/filetype_query.go | 15 +- entc/integration/ent/goods_query.go | 15 +- entc/integration/ent/group_query.go | 99 +++-- entc/integration/ent/groupinfo_query.go | 15 +- entc/integration/ent/item_query.go | 15 +- entc/integration/ent/node_query.go | 15 +- entc/integration/ent/pet_query.go | 15 +- entc/integration/ent/spec_query.go | 99 +++-- entc/integration/ent/task_query.go | 15 +- entc/integration/ent/user_query.go | 351 ++++++++---------- entc/integration/hooks/ent/card_query.go | 15 +- entc/integration/hooks/ent/ent.go | 4 + entc/integration/hooks/ent/user_query.go | 99 +++-- entc/integration/idtype/ent/ent.go | 4 + entc/integration/idtype/ent/user_query.go | 183 ++++----- entc/integration/json/ent/ent.go | 4 + entc/integration/json/ent/user_query.go | 15 +- entc/integration/migrate/entv1/car_query.go | 15 +- .../migrate/entv1/conversion_query.go | 15 +- .../migrate/entv1/customtype_query.go | 15 +- entc/integration/migrate/entv1/ent.go | 4 + entc/integration/migrate/entv1/user_query.go | 15 +- entc/integration/migrate/entv2/car_query.go | 15 +- .../migrate/entv2/conversion_query.go | 15 +- .../migrate/entv2/customtype_query.go | 15 +- entc/integration/migrate/entv2/ent.go | 4 + entc/integration/migrate/entv2/group_query.go | 15 +- entc/integration/migrate/entv2/media_query.go | 15 +- entc/integration/migrate/entv2/pet_query.go | 15 +- entc/integration/migrate/entv2/user_query.go | 99 +++-- .../migrate/versioned/car_query.go | 15 +- entc/integration/migrate/versioned/ent.go | 4 + .../migrate/versioned/user_query.go | 15 +- entc/integration/multischema/ent/ent.go | 4 + .../multischema/ent/group_query.go | 100 +++-- entc/integration/multischema/ent/pet_query.go | 15 +- .../integration/multischema/ent/user_query.go | 100 +++-- entc/integration/privacy/ent/ent.go | 4 + entc/integration/privacy/ent/task_query.go | 99 +++-- entc/integration/privacy/ent/team_query.go | 183 ++++----- entc/integration/privacy/ent/user_query.go | 99 +++-- entc/integration/template/ent/ent.go | 4 + entc/integration/template/ent/group_query.go | 15 +- entc/integration/template/ent/pet_query.go | 15 +- entc/integration/template/ent/user_query.go | 99 +++-- examples/edgeindex/ent/city_query.go | 15 +- examples/edgeindex/ent/ent.go | 4 + examples/edgeindex/ent/street_query.go | 15 +- examples/entcpkg/ent/ent.go | 4 + examples/entcpkg/ent/user_query.go | 15 +- examples/fs/ent/ent.go | 4 + examples/fs/ent/file_query.go | 15 +- examples/m2m2types/ent/ent.go | 4 + examples/m2m2types/ent/group_query.go | 99 +++-- examples/m2m2types/ent/user_query.go | 99 +++-- examples/m2mbidi/ent/ent.go | 4 + examples/m2mbidi/ent/user_query.go | 99 +++-- examples/m2mrecur/ent/ent.go | 4 + examples/m2mrecur/ent/user_query.go | 183 ++++----- examples/o2m2types/ent/ent.go | 4 + examples/o2m2types/ent/pet_query.go | 15 +- examples/o2m2types/ent/user_query.go | 15 +- examples/o2mrecur/ent/ent.go | 4 + examples/o2mrecur/ent/node_query.go | 15 +- examples/o2o2types/ent/card_query.go | 15 +- examples/o2o2types/ent/ent.go | 4 + examples/o2o2types/ent/user_query.go | 15 +- examples/o2obidi/ent/ent.go | 4 + examples/o2obidi/ent/user_query.go | 15 +- examples/o2orecur/ent/ent.go | 4 + examples/o2orecur/ent/node_query.go | 15 +- examples/privacyadmin/ent/ent.go | 4 + examples/privacyadmin/ent/user_query.go | 15 +- examples/privacytenant/ent/ent.go | 4 + examples/privacytenant/ent/group_query.go | 99 +++-- examples/privacytenant/ent/tenant_query.go | 15 +- examples/privacytenant/ent/user_query.go | 99 +++-- examples/start/ent/car_query.go | 15 +- examples/start/ent/ent.go | 4 + examples/start/ent/group_query.go | 99 +++-- examples/start/ent/user_query.go | 99 +++-- examples/traversal/ent/ent.go | 4 + examples/traversal/ent/group_query.go | 99 +++-- examples/traversal/ent/pet_query.go | 99 +++-- examples/traversal/ent/user_query.go | 183 ++++----- examples/version/ent/ent.go | 4 + examples/version/ent/user_query.go | 15 +- 124 files changed, 2076 insertions(+), 2483 deletions(-) diff --git a/entc/gen/template/dialect/sql/globals.tmpl b/entc/gen/template/dialect/sql/globals.tmpl index 300a300a2d..ed0ccf12b3 100644 --- a/entc/gen/template/dialect/sql/globals.tmpl +++ b/entc/gen/template/dialect/sql/globals.tmpl @@ -7,4 +7,8 @@ in the LICENSE file in the root directory of this source tree. {{/* gotype: entgo.io/ent/entc/gen.Graph */}} {{/* custom globals and helpers for sql dialects */}} -{{ define "dialect/sql/globals" }}{{ end }} +{{ define "dialect/sql/globals" }} + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) +{{ end }} diff --git a/entc/gen/template/dialect/sql/query.tmpl b/entc/gen/template/dialect/sql/query.tmpl index 47045e47d2..560e7e7798 100644 --- a/entc/gen/template/dialect/sql/query.tmpl +++ b/entc/gen/template/dialect/sql/query.tmpl @@ -23,7 +23,7 @@ in the LICENSE file in the root directory of this source tree. {{ $builder := pascal $.Scope.Builder }} {{ $receiver := receiver $builder }} -func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) { +func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) { var ( nodes = []*{{ $.Name }}{} {{- with $.UnexportedForeignKeys }} @@ -49,15 +49,11 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name } {{- end }} _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &{{ $.Name }}{config: {{ $receiver }}.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*{{ $.Name }}).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("{{ $pkg }}: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &{{ $.Name }}{config: {{ $receiver }}.config} + nodes = append(nodes, node) {{- with $.Edges }} node.Edges.loadedTypes = loadedTypes {{- end }} @@ -69,6 +65,9 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name {{- xtemplate $tmpl $ }} {{- end }} {{- end }} + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil { return nil, err } @@ -272,74 +271,58 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select {{- $receiver := $.Scope.Rec }} if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil { {{- if $e.M2M }} - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[{{ $.ID.Type }}]*{{ $.Name }}) + nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} } - var ( - edgeids []{{ $e.Type.ID.Type }} - edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }}) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: {{ $e.IsInverse }}, - Table: {{ $.Package }}.{{ $e.TableConstant }}, - Columns: {{ $.Package }}.{{ $e.PKConstant }}, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...)) - }, + query.Where(func(s *sql.Selector) { + joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }}) + {{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }} + {{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk1idx = 1 }}{{ end }} + s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}])) + s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues {{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }} {{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }} - ScanValues: func() [2]interface{}{ - return [2]interface{}{new({{ $out }}), new({{ $in }})} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*{{ $out }}) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*{{ $in }}) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }} - inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }} - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new({{ $out }})}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }} + inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }} + if nids[inValue] == nil { + nids[inValue] = map[*{{ $.Name }}]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - {{- /* Allow mutating the sqlgraph.EdgeQuerySpec by ent extensions or user templates.*/}} - {{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/spec/*" }} - {{- range $tmpl := $tmpls }} - {{- xtemplate $tmpl $ }} - {{- end }} - {{- end }} - if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %w`, err) - } - query.Where({{ $e.Type.Package }}.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n) + for kn := range nodes { + kn.Edges.{{ $e.StructField }} = append(kn.Edges.{{ $e.StructField }}, n) } } {{- else if $e.OwnFK }} @@ -415,9 +398,9 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select {{- $field := $.Scope.Field }} {{- $scantype := $.Scope.ScanType }} {{- if hasPrefix $scantype "sql" -}} - {{ $field.ScanTypeField $arg -}} + {{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}} {{- else -}} - {{ if not $field.Nillable }}*{{ end }}{{ $arg }} + {{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }} {{- end }} {{- end }} diff --git a/entc/integration/cascadelete/ent/comment_query.go b/entc/integration/cascadelete/ent/comment_query.go index 1a9b199cf6..7fdfca5902 100644 --- a/entc/integration/cascadelete/ent/comment_query.go +++ b/entc/integration/cascadelete/ent/comment_query.go @@ -354,7 +354,7 @@ func (cq *CommentQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CommentQuery) sqlAll(ctx context.Context) ([]*Comment, error) { +func (cq *CommentQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Comment, error) { var ( nodes = []*Comment{} _spec = cq.querySpec() @@ -363,18 +363,17 @@ func (cq *CommentQuery) sqlAll(ctx context.Context) ([]*Comment, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Comment{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Comment).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Comment{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/cascadelete/ent/ent.go b/entc/integration/cascadelete/ent/ent.go index b47983685c..7ddd2024e1 100644 --- a/entc/integration/cascadelete/ent/ent.go +++ b/entc/integration/cascadelete/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/cascadelete/ent/comment" "entgo.io/ent/entc/integration/cascadelete/ent/post" "entgo.io/ent/entc/integration/cascadelete/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/cascadelete/ent/post_query.go b/entc/integration/cascadelete/ent/post_query.go index a30e3fde3b..dce36e0013 100644 --- a/entc/integration/cascadelete/ent/post_query.go +++ b/entc/integration/cascadelete/ent/post_query.go @@ -391,7 +391,7 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PostQuery) sqlAll(ctx context.Context) ([]*Post, error) { +func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) { var ( nodes = []*Post{} _spec = pq.querySpec() @@ -401,18 +401,17 @@ func (pq *PostQuery) sqlAll(ctx context.Context) ([]*Post, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Post{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Post).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Post{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/cascadelete/ent/user_query.go b/entc/integration/cascadelete/ent/user_query.go index 2684df952e..4c4613f701 100644 --- a/entc/integration/cascadelete/ent/user_query.go +++ b/entc/integration/cascadelete/ent/user_query.go @@ -355,7 +355,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -364,18 +364,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/config/ent/ent.go b/entc/integration/config/ent/ent.go index 2f900602da..ad0320803d 100644 --- a/entc/integration/config/ent/ent.go +++ b/entc/integration/config/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/config/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/config/ent/user_query.go b/entc/integration/config/ent/user_query.go index 7596b7e6e4..a284d1bc47 100644 --- a/entc/integration/config/ent/user_query.go +++ b/entc/integration/config/ent/user_query.go @@ -317,23 +317,22 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/account_query.go b/entc/integration/customid/ent/account_query.go index 980df8090a..f806c41388 100644 --- a/entc/integration/customid/ent/account_query.go +++ b/entc/integration/customid/ent/account_query.go @@ -356,7 +356,7 @@ func (aq *AccountQuery) prepareQuery(ctx context.Context) error { return nil } -func (aq *AccountQuery) sqlAll(ctx context.Context) ([]*Account, error) { +func (aq *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Account, error) { var ( nodes = []*Account{} _spec = aq.querySpec() @@ -365,18 +365,17 @@ func (aq *AccountQuery) sqlAll(ctx context.Context) ([]*Account, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Account{config: aq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Account).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Account{config: aq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, aq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/blob_query.go b/entc/integration/customid/ent/blob_query.go index 10d3491ee4..966e7e5d51 100644 --- a/entc/integration/customid/ent/blob_query.go +++ b/entc/integration/customid/ent/blob_query.go @@ -391,7 +391,7 @@ func (bq *BlobQuery) prepareQuery(ctx context.Context) error { return nil } -func (bq *BlobQuery) sqlAll(ctx context.Context) ([]*Blob, error) { +func (bq *BlobQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Blob, error) { var ( nodes = []*Blob{} withFKs = bq.withFKs @@ -408,18 +408,17 @@ func (bq *BlobQuery) sqlAll(ctx context.Context) ([]*Blob, error) { _spec.Node.Columns = append(_spec.Node.Columns, blob.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Blob{config: bq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Blob).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Blob{config: bq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, bq.driver, _spec); err != nil { return nil, err } @@ -457,66 +456,54 @@ func (bq *BlobQuery) sqlAll(ctx context.Context) ([]*Blob, error) { } if query := bq.withLinks; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[uuid.UUID]*Blob, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[uuid.UUID]*Blob) + nids := make(map[uuid.UUID]map[*Blob]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Links = []*Blob{} } - var ( - edgeids []uuid.UUID - edges = make(map[uuid.UUID][]*Blob) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: blob.LinksTable, - Columns: blob.LinksPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(blob.LinksPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(uuid.UUID), new(uuid.UUID)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*uuid.UUID) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(blob.LinksTable) + s.Join(joinT).On(s.C(blob.FieldID), joinT.C(blob.LinksPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(blob.LinksPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(blob.LinksPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*uuid.UUID) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := *eout - inValue := *ein - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(uuid.UUID)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := *values[0].(*uuid.UUID) + inValue := *values[1].(*uuid.UUID) + if nids[inValue] == nil { + nids[inValue] = map[*Blob]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, bq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "links": %w`, err) - } - query.Where(blob.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "links" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Links = append(nodes[i].Edges.Links, n) + for kn := range nodes { + kn.Edges.Links = append(kn.Edges.Links, n) } } } diff --git a/entc/integration/customid/ent/car_query.go b/entc/integration/customid/ent/car_query.go index a4b7906355..8cbacc435c 100644 --- a/entc/integration/customid/ent/car_query.go +++ b/entc/integration/customid/ent/car_query.go @@ -355,7 +355,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { +func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) { var ( nodes = []*Car{} withFKs = cq.withFKs @@ -371,18 +371,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { _spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Car{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Car).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Car{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/device_query.go b/entc/integration/customid/ent/device_query.go index 8e139a24e3..94937680ea 100644 --- a/entc/integration/customid/ent/device_query.go +++ b/entc/integration/customid/ent/device_query.go @@ -368,7 +368,7 @@ func (dq *DeviceQuery) prepareQuery(ctx context.Context) error { return nil } -func (dq *DeviceQuery) sqlAll(ctx context.Context) ([]*Device, error) { +func (dq *DeviceQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Device, error) { var ( nodes = []*Device{} withFKs = dq.withFKs @@ -385,18 +385,17 @@ func (dq *DeviceQuery) sqlAll(ctx context.Context) ([]*Device, error) { _spec.Node.Columns = append(_spec.Node.Columns, device.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Device{config: dq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Device).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Device{config: dq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, dq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/doc_query.go b/entc/integration/customid/ent/doc_query.go index 635544e96c..d373bb22c3 100644 --- a/entc/integration/customid/ent/doc_query.go +++ b/entc/integration/customid/ent/doc_query.go @@ -391,7 +391,7 @@ func (dq *DocQuery) prepareQuery(ctx context.Context) error { return nil } -func (dq *DocQuery) sqlAll(ctx context.Context) ([]*Doc, error) { +func (dq *DocQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Doc, error) { var ( nodes = []*Doc{} withFKs = dq.withFKs @@ -408,18 +408,17 @@ func (dq *DocQuery) sqlAll(ctx context.Context) ([]*Doc, error) { _spec.Node.Columns = append(_spec.Node.Columns, doc.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Doc{config: dq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Doc).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Doc{config: dq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, dq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/ent.go b/entc/integration/customid/ent/ent.go index 7937569aca..b9082e1630 100644 --- a/entc/integration/customid/ent/ent.go +++ b/entc/integration/customid/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/customid/ent/account" "entgo.io/ent/entc/integration/customid/ent/blob" "entgo.io/ent/entc/integration/customid/ent/car" @@ -488,3 +489,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/customid/ent/group_query.go b/entc/integration/customid/ent/group_query.go index 1f30e7bacd..7e8df9d1ce 100644 --- a/entc/integration/customid/ent/group_query.go +++ b/entc/integration/customid/ent/group_query.go @@ -331,7 +331,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} _spec = gq.querySpec() @@ -340,18 +340,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -360,66 +359,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/entc/integration/customid/ent/mixinid_query.go b/entc/integration/customid/ent/mixinid_query.go index d478d29c5e..5a1faaa19d 100644 --- a/entc/integration/customid/ent/mixinid_query.go +++ b/entc/integration/customid/ent/mixinid_query.go @@ -318,23 +318,22 @@ func (miq *MixinIDQuery) prepareQuery(ctx context.Context) error { return nil } -func (miq *MixinIDQuery) sqlAll(ctx context.Context) ([]*MixinID, error) { +func (miq *MixinIDQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*MixinID, error) { var ( nodes = []*MixinID{} _spec = miq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &MixinID{config: miq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*MixinID).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &MixinID{config: miq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, miq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/note_query.go b/entc/integration/customid/ent/note_query.go index b0ae7d6454..fffdaae0ec 100644 --- a/entc/integration/customid/ent/note_query.go +++ b/entc/integration/customid/ent/note_query.go @@ -391,7 +391,7 @@ func (nq *NoteQuery) prepareQuery(ctx context.Context) error { return nil } -func (nq *NoteQuery) sqlAll(ctx context.Context) ([]*Note, error) { +func (nq *NoteQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Note, error) { var ( nodes = []*Note{} withFKs = nq.withFKs @@ -408,18 +408,17 @@ func (nq *NoteQuery) sqlAll(ctx context.Context) ([]*Note, error) { _spec.Node.Columns = append(_spec.Node.Columns, note.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Note{config: nq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Note).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Note{config: nq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, nq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/other_query.go b/entc/integration/customid/ent/other_query.go index 869d07e728..10c9163920 100644 --- a/entc/integration/customid/ent/other_query.go +++ b/entc/integration/customid/ent/other_query.go @@ -294,23 +294,22 @@ func (oq *OtherQuery) prepareQuery(ctx context.Context) error { return nil } -func (oq *OtherQuery) sqlAll(ctx context.Context) ([]*Other, error) { +func (oq *OtherQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Other, error) { var ( nodes = []*Other{} _spec = oq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Other{config: oq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Other).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Other{config: oq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, oq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/pet_query.go b/entc/integration/customid/ent/pet_query.go index 39b59cc1c7..14765d413d 100644 --- a/entc/integration/customid/ent/pet_query.go +++ b/entc/integration/customid/ent/pet_query.go @@ -438,7 +438,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} withFKs = pq.withFKs @@ -457,18 +457,17 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } @@ -535,66 +534,54 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { } if query := pq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[string]*Pet, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[string]*Pet) + nids := make(map[string]map[*Pet]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*Pet{} } - var ( - edgeids []string - edges = make(map[string][]*Pet) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: pet.FriendsTable, - Columns: pet.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(pet.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullString), new(sql.NullString)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullString) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullString) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := eout.String - inValue := ein.String - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(pet.FriendsTable) + s.Join(joinT).On(s.C(pet.FieldID), joinT.C(pet.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(pet.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(pet.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullString)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := values[0].(*sql.NullString).String + inValue := values[1].(*sql.NullString).String + if nids[inValue] == nil { + nids[inValue] = map[*Pet]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, pq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(pet.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } diff --git a/entc/integration/customid/ent/session_query.go b/entc/integration/customid/ent/session_query.go index 4936b00aec..83cb355d67 100644 --- a/entc/integration/customid/ent/session_query.go +++ b/entc/integration/customid/ent/session_query.go @@ -332,7 +332,7 @@ func (sq *SessionQuery) prepareQuery(ctx context.Context) error { return nil } -func (sq *SessionQuery) sqlAll(ctx context.Context) ([]*Session, error) { +func (sq *SessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Session, error) { var ( nodes = []*Session{} withFKs = sq.withFKs @@ -348,18 +348,17 @@ func (sq *SessionQuery) sqlAll(ctx context.Context) ([]*Session, error) { _spec.Node.Columns = append(_spec.Node.Columns, session.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Session{config: sq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Session).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Session{config: sq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, sq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/token_query.go b/entc/integration/customid/ent/token_query.go index 7c38e26521..dd4dcce56d 100644 --- a/entc/integration/customid/ent/token_query.go +++ b/entc/integration/customid/ent/token_query.go @@ -356,7 +356,7 @@ func (tq *TokenQuery) prepareQuery(ctx context.Context) error { return nil } -func (tq *TokenQuery) sqlAll(ctx context.Context) ([]*Token, error) { +func (tq *TokenQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Token, error) { var ( nodes = []*Token{} withFKs = tq.withFKs @@ -372,18 +372,17 @@ func (tq *TokenQuery) sqlAll(ctx context.Context) ([]*Token, error) { _spec.Node.Columns = append(_spec.Node.Columns, token.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Token{config: tq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Token).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Token{config: tq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/customid/ent/user_query.go b/entc/integration/customid/ent/user_query.go index 8478fa2983..87c356b6d7 100644 --- a/entc/integration/customid/ent/user_query.go +++ b/entc/integration/customid/ent/user_query.go @@ -438,7 +438,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -457,18 +457,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -477,66 +476,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } diff --git a/entc/integration/edgefield/ent/car_query.go b/entc/integration/edgefield/ent/car_query.go index 13374c5e49..f44918b293 100644 --- a/entc/integration/edgefield/ent/car_query.go +++ b/entc/integration/edgefield/ent/car_query.go @@ -356,7 +356,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { +func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) { var ( nodes = []*Car{} _spec = cq.querySpec() @@ -365,18 +365,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Car{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Car).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Car{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/card_query.go b/entc/integration/edgefield/ent/card_query.go index 9de5eaf0ea..8d5997c7c3 100644 --- a/entc/integration/edgefield/ent/card_query.go +++ b/entc/integration/edgefield/ent/card_query.go @@ -354,7 +354,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { +func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, error) { var ( nodes = []*Card{} _spec = cq.querySpec() @@ -363,18 +363,17 @@ func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Card{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Card).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Card{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/ent.go b/entc/integration/edgefield/ent/ent.go index 16299e1f31..5e32f13da0 100644 --- a/entc/integration/edgefield/ent/ent.go +++ b/entc/integration/edgefield/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/edgefield/ent/car" "entgo.io/ent/entc/integration/edgefield/ent/card" "entgo.io/ent/entc/integration/edgefield/ent/info" @@ -480,3 +481,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/edgefield/ent/info_query.go b/entc/integration/edgefield/ent/info_query.go index 14f330db3e..2f047da026 100644 --- a/entc/integration/edgefield/ent/info_query.go +++ b/entc/integration/edgefield/ent/info_query.go @@ -354,7 +354,7 @@ func (iq *InfoQuery) prepareQuery(ctx context.Context) error { return nil } -func (iq *InfoQuery) sqlAll(ctx context.Context) ([]*Info, error) { +func (iq *InfoQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Info, error) { var ( nodes = []*Info{} _spec = iq.querySpec() @@ -363,18 +363,17 @@ func (iq *InfoQuery) sqlAll(ctx context.Context) ([]*Info, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Info{config: iq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Info).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Info{config: iq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, iq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/metadata_query.go b/entc/integration/edgefield/ent/metadata_query.go index 7ca5f532d7..1e4d1ed610 100644 --- a/entc/integration/edgefield/ent/metadata_query.go +++ b/entc/integration/edgefield/ent/metadata_query.go @@ -425,7 +425,7 @@ func (mq *MetadataQuery) prepareQuery(ctx context.Context) error { return nil } -func (mq *MetadataQuery) sqlAll(ctx context.Context) ([]*Metadata, error) { +func (mq *MetadataQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Metadata, error) { var ( nodes = []*Metadata{} _spec = mq.querySpec() @@ -436,18 +436,17 @@ func (mq *MetadataQuery) sqlAll(ctx context.Context) ([]*Metadata, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Metadata{config: mq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Metadata).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Metadata{config: mq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/node_query.go b/entc/integration/edgefield/ent/node_query.go index 3148566101..0f91c16c28 100644 --- a/entc/integration/edgefield/ent/node_query.go +++ b/entc/integration/edgefield/ent/node_query.go @@ -389,7 +389,7 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { return nil } -func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { +func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, error) { var ( nodes = []*Node{} _spec = nq.querySpec() @@ -399,18 +399,17 @@ func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Node{config: nq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Node).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Node{config: nq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, nq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/pet_query.go b/entc/integration/edgefield/ent/pet_query.go index c2090ceadf..4b477bf0d5 100644 --- a/entc/integration/edgefield/ent/pet_query.go +++ b/entc/integration/edgefield/ent/pet_query.go @@ -354,7 +354,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} _spec = pq.querySpec() @@ -363,18 +363,17 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/post_query.go b/entc/integration/edgefield/ent/post_query.go index 806e48aa24..7591eeba42 100644 --- a/entc/integration/edgefield/ent/post_query.go +++ b/entc/integration/edgefield/ent/post_query.go @@ -354,7 +354,7 @@ func (pq *PostQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PostQuery) sqlAll(ctx context.Context) ([]*Post, error) { +func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, error) { var ( nodes = []*Post{} _spec = pq.querySpec() @@ -363,18 +363,17 @@ func (pq *PostQuery) sqlAll(ctx context.Context) ([]*Post, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Post{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Post).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Post{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/rental_query.go b/entc/integration/edgefield/ent/rental_query.go index 7e73d0f134..8f81b522e6 100644 --- a/entc/integration/edgefield/ent/rental_query.go +++ b/entc/integration/edgefield/ent/rental_query.go @@ -391,7 +391,7 @@ func (rq *RentalQuery) prepareQuery(ctx context.Context) error { return nil } -func (rq *RentalQuery) sqlAll(ctx context.Context) ([]*Rental, error) { +func (rq *RentalQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Rental, error) { var ( nodes = []*Rental{} _spec = rq.querySpec() @@ -401,18 +401,17 @@ func (rq *RentalQuery) sqlAll(ctx context.Context) ([]*Rental, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Rental{config: rq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Rental).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Rental{config: rq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, rq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/edgefield/ent/user_query.go b/entc/integration/edgefield/ent/user_query.go index fafed2bf03..a33f7dfb79 100644 --- a/entc/integration/edgefield/ent/user_query.go +++ b/entc/integration/edgefield/ent/user_query.go @@ -604,7 +604,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -620,18 +620,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/card_query.go b/entc/integration/ent/card_query.go index 420c20f1bf..42fc8bcd8d 100644 --- a/entc/integration/ent/card_query.go +++ b/entc/integration/ent/card_query.go @@ -394,7 +394,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { +func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, error) { var ( nodes = []*Card{} withFKs = cq.withFKs @@ -411,21 +411,20 @@ func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { _spec.Node.Columns = append(_spec.Node.Columns, card.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Card{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Card).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Card{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(cq.modifiers) > 0 { _spec.Modifiers = cq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } @@ -463,66 +462,54 @@ func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { } if query := cq.withSpec; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Card, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Card) + nids := make(map[int]map[*Card]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Spec = []*Spec{} } - var ( - edgeids []int - edges = make(map[int][]*Card) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: card.SpecTable, - Columns: card.SpecPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(card.SpecPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(card.SpecTable) + s.Join(joinT).On(s.C(spec.FieldID), joinT.C(card.SpecPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(card.SpecPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(card.SpecPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Card]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, cq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "spec": %w`, err) - } - query.Where(spec.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "spec" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Spec = append(nodes[i].Edges.Spec, n) + for kn := range nodes { + kn.Edges.Spec = append(kn.Edges.Spec, n) } } } diff --git a/entc/integration/ent/comment_query.go b/entc/integration/ent/comment_query.go index 5e246f33d6..2786e0e927 100644 --- a/entc/integration/ent/comment_query.go +++ b/entc/integration/ent/comment_query.go @@ -319,26 +319,25 @@ func (cq *CommentQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CommentQuery) sqlAll(ctx context.Context) ([]*Comment, error) { +func (cq *CommentQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Comment, error) { var ( nodes = []*Comment{} _spec = cq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Comment{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Comment).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Comment{config: cq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } if len(cq.modifiers) > 0 { _spec.Modifiers = cq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/ent.go b/entc/integration/ent/ent.go index b710573bd7..66a16e24fe 100644 --- a/entc/integration/ent/ent.go +++ b/entc/integration/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/ent/card" "entgo.io/ent/entc/integration/ent/comment" "entgo.io/ent/entc/integration/ent/fieldtype" @@ -490,3 +491,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/ent/fieldtype_query.go b/entc/integration/ent/fieldtype_query.go index e0d79573dd..47c90a9596 100644 --- a/entc/integration/ent/fieldtype_query.go +++ b/entc/integration/ent/fieldtype_query.go @@ -320,7 +320,7 @@ func (ftq *FieldTypeQuery) prepareQuery(ctx context.Context) error { return nil } -func (ftq *FieldTypeQuery) sqlAll(ctx context.Context) ([]*FieldType, error) { +func (ftq *FieldTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*FieldType, error) { var ( nodes = []*FieldType{} withFKs = ftq.withFKs @@ -330,20 +330,19 @@ func (ftq *FieldTypeQuery) sqlAll(ctx context.Context) ([]*FieldType, error) { _spec.Node.Columns = append(_spec.Node.Columns, fieldtype.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &FieldType{config: ftq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*FieldType).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &FieldType{config: ftq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } if len(ftq.modifiers) > 0 { _spec.Modifiers = ftq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, ftq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/file_query.go b/entc/integration/ent/file_query.go index 6eade95653..e79495bec5 100644 --- a/entc/integration/ent/file_query.go +++ b/entc/integration/ent/file_query.go @@ -430,7 +430,7 @@ func (fq *FileQuery) prepareQuery(ctx context.Context) error { return nil } -func (fq *FileQuery) sqlAll(ctx context.Context) ([]*File, error) { +func (fq *FileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*File, error) { var ( nodes = []*File{} withFKs = fq.withFKs @@ -448,21 +448,20 @@ func (fq *FileQuery) sqlAll(ctx context.Context) ([]*File, error) { _spec.Node.Columns = append(_spec.Node.Columns, file.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &File{config: fq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*File).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &File{config: fq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(fq.modifiers) > 0 { _spec.Modifiers = fq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, fq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/filetype_query.go b/entc/integration/ent/filetype_query.go index 9c7fc4223d..e8dd352290 100644 --- a/entc/integration/ent/filetype_query.go +++ b/entc/integration/ent/filetype_query.go @@ -357,7 +357,7 @@ func (ftq *FileTypeQuery) prepareQuery(ctx context.Context) error { return nil } -func (ftq *FileTypeQuery) sqlAll(ctx context.Context) ([]*FileType, error) { +func (ftq *FileTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*FileType, error) { var ( nodes = []*FileType{} _spec = ftq.querySpec() @@ -366,21 +366,20 @@ func (ftq *FileTypeQuery) sqlAll(ctx context.Context) ([]*FileType, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &FileType{config: ftq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*FileType).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &FileType{config: ftq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(ftq.modifiers) > 0 { _spec.Modifiers = ftq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, ftq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/goods_query.go b/entc/integration/ent/goods_query.go index 37db1699e0..16f4bac4dc 100644 --- a/entc/integration/ent/goods_query.go +++ b/entc/integration/ent/goods_query.go @@ -295,26 +295,25 @@ func (gq *GoodsQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GoodsQuery) sqlAll(ctx context.Context) ([]*Goods, error) { +func (gq *GoodsQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Goods, error) { var ( nodes = []*Goods{} _spec = gq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Goods{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Goods).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Goods{config: gq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/group_query.go b/entc/integration/ent/group_query.go index 5499a9ffba..41abcdccc2 100644 --- a/entc/integration/ent/group_query.go +++ b/entc/integration/ent/group_query.go @@ -465,7 +465,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} withFKs = gq.withFKs @@ -484,21 +484,20 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { _spec.Node.Columns = append(_spec.Node.Columns, group.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -565,66 +564,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/entc/integration/ent/groupinfo_query.go b/entc/integration/ent/groupinfo_query.go index 125ff24de5..f4fddd7d12 100644 --- a/entc/integration/ent/groupinfo_query.go +++ b/entc/integration/ent/groupinfo_query.go @@ -357,7 +357,7 @@ func (giq *GroupInfoQuery) prepareQuery(ctx context.Context) error { return nil } -func (giq *GroupInfoQuery) sqlAll(ctx context.Context) ([]*GroupInfo, error) { +func (giq *GroupInfoQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*GroupInfo, error) { var ( nodes = []*GroupInfo{} _spec = giq.querySpec() @@ -366,21 +366,20 @@ func (giq *GroupInfoQuery) sqlAll(ctx context.Context) ([]*GroupInfo, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &GroupInfo{config: giq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*GroupInfo).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &GroupInfo{config: giq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(giq.modifiers) > 0 { _spec.Modifiers = giq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, giq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/item_query.go b/entc/integration/ent/item_query.go index 32fdfa76a5..6ead8f1ce9 100644 --- a/entc/integration/ent/item_query.go +++ b/entc/integration/ent/item_query.go @@ -319,26 +319,25 @@ func (iq *ItemQuery) prepareQuery(ctx context.Context) error { return nil } -func (iq *ItemQuery) sqlAll(ctx context.Context) ([]*Item, error) { +func (iq *ItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Item, error) { var ( nodes = []*Item{} _spec = iq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Item{config: iq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Item).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Item{config: iq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } if len(iq.modifiers) > 0 { _spec.Modifiers = iq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, iq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/node_query.go b/entc/integration/ent/node_query.go index ca0a580838..a82b000ec4 100644 --- a/entc/integration/ent/node_query.go +++ b/entc/integration/ent/node_query.go @@ -392,7 +392,7 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { return nil } -func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { +func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, error) { var ( nodes = []*Node{} withFKs = nq.withFKs @@ -409,21 +409,20 @@ func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { _spec.Node.Columns = append(_spec.Node.Columns, node.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Node{config: nq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Node).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Node{config: nq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(nq.modifiers) > 0 { _spec.Modifiers = nq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, nq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/pet_query.go b/entc/integration/ent/pet_query.go index 6816719bb0..1213381954 100644 --- a/entc/integration/ent/pet_query.go +++ b/entc/integration/ent/pet_query.go @@ -392,7 +392,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} withFKs = pq.withFKs @@ -409,21 +409,20 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(pq.modifiers) > 0 { _spec.Modifiers = pq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/spec_query.go b/entc/integration/ent/spec_query.go index a3f62d1b4a..9172d6e8fc 100644 --- a/entc/integration/ent/spec_query.go +++ b/entc/integration/ent/spec_query.go @@ -333,7 +333,7 @@ func (sq *SpecQuery) prepareQuery(ctx context.Context) error { return nil } -func (sq *SpecQuery) sqlAll(ctx context.Context) ([]*Spec, error) { +func (sq *SpecQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Spec, error) { var ( nodes = []*Spec{} _spec = sq.querySpec() @@ -342,21 +342,20 @@ func (sq *SpecQuery) sqlAll(ctx context.Context) ([]*Spec, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Spec{config: sq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Spec).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Spec{config: sq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(sq.modifiers) > 0 { _spec.Modifiers = sq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, sq.driver, _spec); err != nil { return nil, err } @@ -365,66 +364,54 @@ func (sq *SpecQuery) sqlAll(ctx context.Context) ([]*Spec, error) { } if query := sq.withCard; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Spec, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Spec) + nids := make(map[int]map[*Spec]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Card = []*Card{} } - var ( - edgeids []int - edges = make(map[int][]*Spec) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: spec.CardTable, - Columns: spec.CardPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(spec.CardPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(spec.CardTable) + s.Join(joinT).On(s.C(card.FieldID), joinT.C(spec.CardPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(spec.CardPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(spec.CardPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Spec]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, sq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "card": %w`, err) - } - query.Where(card.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "card" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Card = append(nodes[i].Edges.Card, n) + for kn := range nodes { + kn.Edges.Card = append(kn.Edges.Card, n) } } } diff --git a/entc/integration/ent/task_query.go b/entc/integration/ent/task_query.go index 841dd8131a..7ef8f972f8 100644 --- a/entc/integration/ent/task_query.go +++ b/entc/integration/ent/task_query.go @@ -320,26 +320,25 @@ func (tq *TaskQuery) prepareQuery(ctx context.Context) error { return nil } -func (tq *TaskQuery) sqlAll(ctx context.Context) ([]*Task, error) { +func (tq *TaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Task, error) { var ( nodes = []*Task{} _spec = tq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Task{config: tq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Task).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Task{config: tq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } if len(tq.modifiers) > 0 { _spec.Modifiers = tq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/ent/user_query.go b/entc/integration/ent/user_query.go index aa0f5f7c4a..92a9c2e30f 100644 --- a/entc/integration/ent/user_query.go +++ b/entc/integration/ent/user_query.go @@ -711,7 +711,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -737,21 +737,20 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -846,261 +845,213 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } if query := uq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FriendsTable, - Columns: user.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } if query := uq.withFollowers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Followers = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.FollowersTable, - Columns: user.FollowersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FollowersPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "followers": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Followers = append(nodes[i].Edges.Followers, n) + for kn := range nodes { + kn.Edges.Followers = append(kn.Edges.Followers, n) } } } if query := uq.withFollowing; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Following = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FollowingTable, - Columns: user.FollowingPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FollowingPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowingTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowingPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "following": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "following" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Following = append(nodes[i].Edges.Following, n) + for kn := range nodes { + kn.Edges.Following = append(kn.Edges.Following, n) } } } diff --git a/entc/integration/hooks/ent/card_query.go b/entc/integration/hooks/ent/card_query.go index 7278ba44d5..a7f681f7dc 100644 --- a/entc/integration/hooks/ent/card_query.go +++ b/entc/integration/hooks/ent/card_query.go @@ -355,7 +355,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { +func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, error) { var ( nodes = []*Card{} withFKs = cq.withFKs @@ -371,18 +371,17 @@ func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { _spec.Node.Columns = append(_spec.Node.Columns, card.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Card{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Card).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Card{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/hooks/ent/ent.go b/entc/integration/hooks/ent/ent.go index 08f4e12f11..2dbae67ea4 100644 --- a/entc/integration/hooks/ent/ent.go +++ b/entc/integration/hooks/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/hooks/ent/card" "entgo.io/ent/entc/integration/hooks/ent/user" ) @@ -466,3 +467,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/hooks/ent/user_query.go b/entc/integration/hooks/ent/user_query.go index ff437ec4ef..fa3a602f7e 100644 --- a/entc/integration/hooks/ent/user_query.go +++ b/entc/integration/hooks/ent/user_query.go @@ -426,7 +426,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -444,18 +444,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -493,66 +492,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FriendsTable, - Columns: user.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } diff --git a/entc/integration/idtype/ent/ent.go b/entc/integration/idtype/ent/ent.go index a119012467..d2f36f21f3 100644 --- a/entc/integration/idtype/ent/ent.go +++ b/entc/integration/idtype/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/idtype/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/idtype/ent/user_query.go b/entc/integration/idtype/ent/user_query.go index 95d009838b..5822cc05df 100644 --- a/entc/integration/idtype/ent/user_query.go +++ b/entc/integration/idtype/ent/user_query.go @@ -425,7 +425,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -443,18 +443,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -492,131 +491,107 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFollowers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[uint64]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[uint64]*User) + nids := make(map[uint64]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Followers = []*User{} } - var ( - edgeids []uint64 - edges = make(map[uint64][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.FollowersTable, - Columns: user.FollowersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FollowersPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := uint64(eout.Int64) - inValue := uint64(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := uint64(values[0].(*sql.NullInt64).Int64) + inValue := uint64(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "followers": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Followers = append(nodes[i].Edges.Followers, n) + for kn := range nodes { + kn.Edges.Followers = append(kn.Edges.Followers, n) } } } if query := uq.withFollowing; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[uint64]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[uint64]*User) + nids := make(map[uint64]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Following = []*User{} } - var ( - edgeids []uint64 - edges = make(map[uint64][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FollowingTable, - Columns: user.FollowingPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FollowingPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowingTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowingPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := uint64(eout.Int64) - inValue := uint64(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := uint64(values[0].(*sql.NullInt64).Int64) + inValue := uint64(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "following": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "following" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Following = append(nodes[i].Edges.Following, n) + for kn := range nodes { + kn.Edges.Following = append(kn.Edges.Following, n) } } } diff --git a/entc/integration/json/ent/ent.go b/entc/integration/json/ent/ent.go index 0d530e2ee4..99cf0e127d 100644 --- a/entc/integration/json/ent/ent.go +++ b/entc/integration/json/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/json/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/json/ent/user_query.go b/entc/integration/json/ent/user_query.go index 0c47e1be99..d0580a9f57 100644 --- a/entc/integration/json/ent/user_query.go +++ b/entc/integration/json/ent/user_query.go @@ -317,23 +317,22 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv1/car_query.go b/entc/integration/migrate/entv1/car_query.go index 3e4ec55147..06db587dd4 100644 --- a/entc/integration/migrate/entv1/car_query.go +++ b/entc/integration/migrate/entv1/car_query.go @@ -331,7 +331,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { +func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) { var ( nodes = []*Car{} withFKs = cq.withFKs @@ -347,18 +347,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { _spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Car{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Car).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv1: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Car{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv1/conversion_query.go b/entc/integration/migrate/entv1/conversion_query.go index edb5373e66..5d8b5a2005 100644 --- a/entc/integration/migrate/entv1/conversion_query.go +++ b/entc/integration/migrate/entv1/conversion_query.go @@ -317,23 +317,22 @@ func (cq *ConversionQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *ConversionQuery) sqlAll(ctx context.Context) ([]*Conversion, error) { +func (cq *ConversionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Conversion, error) { var ( nodes = []*Conversion{} _spec = cq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Conversion{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Conversion).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv1: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Conversion{config: cq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv1/customtype_query.go b/entc/integration/migrate/entv1/customtype_query.go index babdb0d8f3..55f60286c3 100644 --- a/entc/integration/migrate/entv1/customtype_query.go +++ b/entc/integration/migrate/entv1/customtype_query.go @@ -317,23 +317,22 @@ func (ctq *CustomTypeQuery) prepareQuery(ctx context.Context) error { return nil } -func (ctq *CustomTypeQuery) sqlAll(ctx context.Context) ([]*CustomType, error) { +func (ctq *CustomTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*CustomType, error) { var ( nodes = []*CustomType{} _spec = ctq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &CustomType{config: ctq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*CustomType).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv1: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &CustomType{config: ctq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, ctq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv1/ent.go b/entc/integration/migrate/entv1/ent.go index ddd6e04349..86ff5d6006 100644 --- a/entc/integration/migrate/entv1/ent.go +++ b/entc/integration/migrate/entv1/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/entv1/car" "entgo.io/ent/entc/integration/migrate/entv1/conversion" "entgo.io/ent/entc/integration/migrate/entv1/customtype" @@ -470,3 +471,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/migrate/entv1/user_query.go b/entc/integration/migrate/entv1/user_query.go index 6c976138a4..18a4bf76a8 100644 --- a/entc/integration/migrate/entv1/user_query.go +++ b/entc/integration/migrate/entv1/user_query.go @@ -461,7 +461,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -480,18 +480,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv1: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/car_query.go b/entc/integration/migrate/entv2/car_query.go index 9f11625995..dae4e7be12 100644 --- a/entc/integration/migrate/entv2/car_query.go +++ b/entc/integration/migrate/entv2/car_query.go @@ -331,7 +331,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { +func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) { var ( nodes = []*Car{} withFKs = cq.withFKs @@ -347,18 +347,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { _spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Car{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Car).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Car{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/conversion_query.go b/entc/integration/migrate/entv2/conversion_query.go index 62363b7064..95658b33cd 100644 --- a/entc/integration/migrate/entv2/conversion_query.go +++ b/entc/integration/migrate/entv2/conversion_query.go @@ -317,23 +317,22 @@ func (cq *ConversionQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *ConversionQuery) sqlAll(ctx context.Context) ([]*Conversion, error) { +func (cq *ConversionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Conversion, error) { var ( nodes = []*Conversion{} _spec = cq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Conversion{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Conversion).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Conversion{config: cq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/customtype_query.go b/entc/integration/migrate/entv2/customtype_query.go index 7b5a731872..affe1252b5 100644 --- a/entc/integration/migrate/entv2/customtype_query.go +++ b/entc/integration/migrate/entv2/customtype_query.go @@ -317,23 +317,22 @@ func (ctq *CustomTypeQuery) prepareQuery(ctx context.Context) error { return nil } -func (ctq *CustomTypeQuery) sqlAll(ctx context.Context) ([]*CustomType, error) { +func (ctq *CustomTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*CustomType, error) { var ( nodes = []*CustomType{} _spec = ctq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &CustomType{config: ctq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*CustomType).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &CustomType{config: ctq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, ctq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/ent.go b/entc/integration/migrate/entv2/ent.go index 23ed9e4826..3254bd2fb3 100644 --- a/entc/integration/migrate/entv2/ent.go +++ b/entc/integration/migrate/entv2/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/entv2/car" "entgo.io/ent/entc/integration/migrate/entv2/conversion" "entgo.io/ent/entc/integration/migrate/entv2/customtype" @@ -476,3 +477,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/migrate/entv2/group_query.go b/entc/integration/migrate/entv2/group_query.go index 564597ce45..bb284a21d7 100644 --- a/entc/integration/migrate/entv2/group_query.go +++ b/entc/integration/migrate/entv2/group_query.go @@ -293,23 +293,22 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} _spec = gq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/media_query.go b/entc/integration/migrate/entv2/media_query.go index 8ad4c2d148..707f178188 100644 --- a/entc/integration/migrate/entv2/media_query.go +++ b/entc/integration/migrate/entv2/media_query.go @@ -317,23 +317,22 @@ func (mq *MediaQuery) prepareQuery(ctx context.Context) error { return nil } -func (mq *MediaQuery) sqlAll(ctx context.Context) ([]*Media, error) { +func (mq *MediaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Media, error) { var ( nodes = []*Media{} _spec = mq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Media{config: mq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Media).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Media{config: mq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/pet_query.go b/entc/integration/migrate/entv2/pet_query.go index 94bbe88a45..7f198ce3ad 100644 --- a/entc/integration/migrate/entv2/pet_query.go +++ b/entc/integration/migrate/entv2/pet_query.go @@ -331,7 +331,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} withFKs = pq.withFKs @@ -347,18 +347,17 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/entv2/user_query.go b/entc/integration/migrate/entv2/user_query.go index 30d672f83e..c899fdcce5 100644 --- a/entc/integration/migrate/entv2/user_query.go +++ b/entc/integration/migrate/entv2/user_query.go @@ -426,7 +426,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -437,18 +437,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("entv2: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -514,66 +513,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FriendsTable, - Columns: user.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } diff --git a/entc/integration/migrate/versioned/car_query.go b/entc/integration/migrate/versioned/car_query.go index 2da2bf1990..6960a21bbc 100644 --- a/entc/integration/migrate/versioned/car_query.go +++ b/entc/integration/migrate/versioned/car_query.go @@ -331,7 +331,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { +func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) { var ( nodes = []*Car{} withFKs = cq.withFKs @@ -347,18 +347,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { _spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Car{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Car).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("versioned: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Car{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/migrate/versioned/ent.go b/entc/integration/migrate/versioned/ent.go index fbed0759e8..f3e5ec9e12 100644 --- a/entc/integration/migrate/versioned/ent.go +++ b/entc/integration/migrate/versioned/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/migrate/versioned/car" "entgo.io/ent/entc/integration/migrate/versioned/user" ) @@ -466,3 +467,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/migrate/versioned/user_query.go b/entc/integration/migrate/versioned/user_query.go index 82ef6a3df9..70949a00d2 100644 --- a/entc/integration/migrate/versioned/user_query.go +++ b/entc/integration/migrate/versioned/user_query.go @@ -461,7 +461,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -480,18 +480,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("versioned: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/multischema/ent/ent.go b/entc/integration/multischema/ent/ent.go index 56e9a28144..541fd3dd4e 100644 --- a/entc/integration/multischema/ent/ent.go +++ b/entc/integration/multischema/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/multischema/ent/group" "entgo.io/ent/entc/integration/multischema/ent/pet" "entgo.io/ent/entc/integration/multischema/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/multischema/ent/group_query.go b/entc/integration/multischema/ent/group_query.go index a0d3ccb73a..adbc6f61f7 100644 --- a/entc/integration/multischema/ent/group_query.go +++ b/entc/integration/multischema/ent/group_query.go @@ -360,7 +360,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} _spec = gq.querySpec() @@ -369,15 +369,11 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } @@ -386,6 +382,9 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -394,67 +393,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - _spec.Edge.Schema = gq.schemaConfig.GroupUsers - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/entc/integration/multischema/ent/pet_query.go b/entc/integration/multischema/ent/pet_query.go index c104fcee81..15dd9c1d24 100644 --- a/entc/integration/multischema/ent/pet_query.go +++ b/entc/integration/multischema/ent/pet_query.go @@ -359,7 +359,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} _spec = pq.querySpec() @@ -368,15 +368,11 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } @@ -385,6 +381,9 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { if len(pq.modifiers) > 0 { _spec.Modifiers = pq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/multischema/ent/user_query.go b/entc/integration/multischema/ent/user_query.go index bb5dc7f2bd..7bbe0268ab 100644 --- a/entc/integration/multischema/ent/user_query.go +++ b/entc/integration/multischema/ent/user_query.go @@ -399,7 +399,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -409,15 +409,11 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } @@ -426,6 +422,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -459,67 +458,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - _spec.Edge.Schema = uq.schemaConfig.GroupUsers - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } diff --git a/entc/integration/privacy/ent/ent.go b/entc/integration/privacy/ent/ent.go index 9228491bbe..e1b7ba4b91 100644 --- a/entc/integration/privacy/ent/ent.go +++ b/entc/integration/privacy/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/privacy/ent/task" "entgo.io/ent/entc/integration/privacy/ent/team" "entgo.io/ent/entc/integration/privacy/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/privacy/ent/task_query.go b/entc/integration/privacy/ent/task_query.go index 2b5e488a03..77332ef703 100644 --- a/entc/integration/privacy/ent/task_query.go +++ b/entc/integration/privacy/ent/task_query.go @@ -399,7 +399,7 @@ func (tq *TaskQuery) prepareQuery(ctx context.Context) error { return nil } -func (tq *TaskQuery) sqlAll(ctx context.Context) ([]*Task, error) { +func (tq *TaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Task, error) { var ( nodes = []*Task{} withFKs = tq.withFKs @@ -416,18 +416,17 @@ func (tq *TaskQuery) sqlAll(ctx context.Context) ([]*Task, error) { _spec.Node.Columns = append(_spec.Node.Columns, task.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Task{config: tq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Task).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Task{config: tq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil { return nil, err } @@ -436,66 +435,54 @@ func (tq *TaskQuery) sqlAll(ctx context.Context) ([]*Task, error) { } if query := tq.withTeams; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Task, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Task) + nids := make(map[int]map[*Task]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Teams = []*Team{} } - var ( - edgeids []int - edges = make(map[int][]*Task) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: task.TeamsTable, - Columns: task.TeamsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(task.TeamsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(task.TeamsTable) + s.Join(joinT).On(s.C(team.FieldID), joinT.C(task.TeamsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(task.TeamsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(task.TeamsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Task]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, tq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "teams": %w`, err) - } - query.Where(team.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "teams" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Teams = append(nodes[i].Edges.Teams, n) + for kn := range nodes { + kn.Edges.Teams = append(kn.Edges.Teams, n) } } } diff --git a/entc/integration/privacy/ent/team_query.go b/entc/integration/privacy/ent/team_query.go index d488354bca..7a85a98b02 100644 --- a/entc/integration/privacy/ent/team_query.go +++ b/entc/integration/privacy/ent/team_query.go @@ -398,7 +398,7 @@ func (tq *TeamQuery) prepareQuery(ctx context.Context) error { return nil } -func (tq *TeamQuery) sqlAll(ctx context.Context) ([]*Team, error) { +func (tq *TeamQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Team, error) { var ( nodes = []*Team{} _spec = tq.querySpec() @@ -408,18 +408,17 @@ func (tq *TeamQuery) sqlAll(ctx context.Context) ([]*Team, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Team{config: tq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Team).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Team{config: tq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil { return nil, err } @@ -428,131 +427,107 @@ func (tq *TeamQuery) sqlAll(ctx context.Context) ([]*Team, error) { } if query := tq.withTasks; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Team, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Team) + nids := make(map[int]map[*Team]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Tasks = []*Task{} } - var ( - edgeids []int - edges = make(map[int][]*Team) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: team.TasksTable, - Columns: team.TasksPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(team.TasksPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(team.TasksTable) + s.Join(joinT).On(s.C(task.FieldID), joinT.C(team.TasksPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(team.TasksPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(team.TasksPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Team]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, tq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "tasks": %w`, err) - } - query.Where(task.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "tasks" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Tasks = append(nodes[i].Edges.Tasks, n) + for kn := range nodes { + kn.Edges.Tasks = append(kn.Edges.Tasks, n) } } } if query := tq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Team, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Team) + nids := make(map[int]map[*Team]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Team) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: team.UsersTable, - Columns: team.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(team.UsersPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(team.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(team.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(team.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(team.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Team]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, tq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/entc/integration/privacy/ent/user_query.go b/entc/integration/privacy/ent/user_query.go index 2816d8449e..087bb05c41 100644 --- a/entc/integration/privacy/ent/user_query.go +++ b/entc/integration/privacy/ent/user_query.go @@ -398,7 +398,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -408,18 +408,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -428,66 +427,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withTeams; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Teams = []*Team{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.TeamsTable, - Columns: user.TeamsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.TeamsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.TeamsTable) + s.Join(joinT).On(s.C(team.FieldID), joinT.C(user.TeamsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.TeamsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.TeamsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "teams": %w`, err) - } - query.Where(team.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "teams" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Teams = append(nodes[i].Edges.Teams, n) + for kn := range nodes { + kn.Edges.Teams = append(kn.Edges.Teams, n) } } } diff --git a/entc/integration/template/ent/ent.go b/entc/integration/template/ent/ent.go index f3d8a09bce..7127f60f8f 100644 --- a/entc/integration/template/ent/ent.go +++ b/entc/integration/template/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/entc/integration/template/ent/group" "entgo.io/ent/entc/integration/template/ent/pet" "entgo.io/ent/entc/integration/template/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/entc/integration/template/ent/group_query.go b/entc/integration/template/ent/group_query.go index 620179e74d..012b68850f 100644 --- a/entc/integration/template/ent/group_query.go +++ b/entc/integration/template/ent/group_query.go @@ -320,26 +320,25 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} _spec = gq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } if len(gq.modifiers) > 0 { _spec.Modifiers = gq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/template/ent/pet_query.go b/entc/integration/template/ent/pet_query.go index 969a79a3d1..77a017c090 100644 --- a/entc/integration/template/ent/pet_query.go +++ b/entc/integration/template/ent/pet_query.go @@ -358,7 +358,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} withFKs = pq.withFKs @@ -374,21 +374,20 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(pq.modifiers) > 0 { _spec.Modifiers = pq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/entc/integration/template/ent/user_query.go b/entc/integration/template/ent/user_query.go index 563813613f..25a518c34b 100644 --- a/entc/integration/template/ent/user_query.go +++ b/entc/integration/template/ent/user_query.go @@ -393,7 +393,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -403,21 +403,20 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } if len(uq.modifiers) > 0 { _spec.Modifiers = uq.modifiers } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -455,66 +454,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FriendsTable, - Columns: user.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } diff --git a/examples/edgeindex/ent/city_query.go b/examples/edgeindex/ent/city_query.go index 5f6cee7e15..8fe846d01e 100644 --- a/examples/edgeindex/ent/city_query.go +++ b/examples/edgeindex/ent/city_query.go @@ -355,7 +355,7 @@ func (cq *CityQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CityQuery) sqlAll(ctx context.Context) ([]*City, error) { +func (cq *CityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*City, error) { var ( nodes = []*City{} _spec = cq.querySpec() @@ -364,18 +364,17 @@ func (cq *CityQuery) sqlAll(ctx context.Context) ([]*City, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &City{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*City).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &City{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/examples/edgeindex/ent/ent.go b/examples/edgeindex/ent/ent.go index fdc2cd9581..f8ffba8e6d 100644 --- a/examples/edgeindex/ent/ent.go +++ b/examples/edgeindex/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/edgeindex/ent/city" "entgo.io/ent/examples/edgeindex/ent/street" ) @@ -466,3 +467,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/edgeindex/ent/street_query.go b/examples/edgeindex/ent/street_query.go index 0b20371152..a586491b04 100644 --- a/examples/edgeindex/ent/street_query.go +++ b/examples/edgeindex/ent/street_query.go @@ -355,7 +355,7 @@ func (sq *StreetQuery) prepareQuery(ctx context.Context) error { return nil } -func (sq *StreetQuery) sqlAll(ctx context.Context) ([]*Street, error) { +func (sq *StreetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Street, error) { var ( nodes = []*Street{} withFKs = sq.withFKs @@ -371,18 +371,17 @@ func (sq *StreetQuery) sqlAll(ctx context.Context) ([]*Street, error) { _spec.Node.Columns = append(_spec.Node.Columns, street.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Street{config: sq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Street).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Street{config: sq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, sq.driver, _spec); err != nil { return nil, err } diff --git a/examples/entcpkg/ent/ent.go b/examples/entcpkg/ent/ent.go index a25423dbce..3314d9e016 100644 --- a/examples/entcpkg/ent/ent.go +++ b/examples/entcpkg/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/entcpkg/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/entcpkg/ent/user_query.go b/examples/entcpkg/ent/user_query.go index 5e2834efe6..0c144d7fb7 100644 --- a/examples/entcpkg/ent/user_query.go +++ b/examples/entcpkg/ent/user_query.go @@ -317,23 +317,22 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/examples/fs/ent/ent.go b/examples/fs/ent/ent.go index 7dfb1ed5a7..88690e4b23 100644 --- a/examples/fs/ent/ent.go +++ b/examples/fs/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/fs/ent/file" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/fs/ent/file_query.go b/examples/fs/ent/file_query.go index 914b4385be..4b651aede9 100644 --- a/examples/fs/ent/file_query.go +++ b/examples/fs/ent/file_query.go @@ -389,7 +389,7 @@ func (fq *FileQuery) prepareQuery(ctx context.Context) error { return nil } -func (fq *FileQuery) sqlAll(ctx context.Context) ([]*File, error) { +func (fq *FileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*File, error) { var ( nodes = []*File{} _spec = fq.querySpec() @@ -399,18 +399,17 @@ func (fq *FileQuery) sqlAll(ctx context.Context) ([]*File, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &File{config: fq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*File).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &File{config: fq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, fq.driver, _spec); err != nil { return nil, err } diff --git a/examples/m2m2types/ent/ent.go b/examples/m2m2types/ent/ent.go index 9bd3798c2a..8edfdba1df 100644 --- a/examples/m2m2types/ent/ent.go +++ b/examples/m2m2types/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/m2m2types/ent/group" "entgo.io/ent/examples/m2m2types/ent/user" ) @@ -466,3 +467,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/m2m2types/ent/group_query.go b/examples/m2m2types/ent/group_query.go index afc239abac..6bbe1bd11b 100644 --- a/examples/m2m2types/ent/group_query.go +++ b/examples/m2m2types/ent/group_query.go @@ -355,7 +355,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} _spec = gq.querySpec() @@ -364,18 +364,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -384,66 +383,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/examples/m2m2types/ent/user_query.go b/examples/m2m2types/ent/user_query.go index 45c4f6062f..9fb06fbcb8 100644 --- a/examples/m2m2types/ent/user_query.go +++ b/examples/m2m2types/ent/user_query.go @@ -355,7 +355,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -364,18 +364,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -384,66 +383,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } diff --git a/examples/m2mbidi/ent/ent.go b/examples/m2mbidi/ent/ent.go index f4ea4dd039..37af3e25fa 100644 --- a/examples/m2mbidi/ent/ent.go +++ b/examples/m2mbidi/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/m2mbidi/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/m2mbidi/ent/user_query.go b/examples/m2mbidi/ent/user_query.go index 97ec4b0bba..39be92e446 100644 --- a/examples/m2mbidi/ent/user_query.go +++ b/examples/m2mbidi/ent/user_query.go @@ -354,7 +354,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -363,18 +363,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -383,66 +382,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FriendsTable, - Columns: user.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } diff --git a/examples/m2mrecur/ent/ent.go b/examples/m2mrecur/ent/ent.go index 7aea071442..030e414f11 100644 --- a/examples/m2mrecur/ent/ent.go +++ b/examples/m2mrecur/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/m2mrecur/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/m2mrecur/ent/user_query.go b/examples/m2mrecur/ent/user_query.go index 941003a3e9..7bdf72fb8b 100644 --- a/examples/m2mrecur/ent/user_query.go +++ b/examples/m2mrecur/ent/user_query.go @@ -389,7 +389,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -399,18 +399,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -419,131 +418,107 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFollowers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Followers = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.FollowersTable, - Columns: user.FollowersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FollowersPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "followers": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Followers = append(nodes[i].Edges.Followers, n) + for kn := range nodes { + kn.Edges.Followers = append(kn.Edges.Followers, n) } } } if query := uq.withFollowing; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Following = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FollowingTable, - Columns: user.FollowingPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FollowingPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowingTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowingPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "following": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "following" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Following = append(nodes[i].Edges.Following, n) + for kn := range nodes { + kn.Edges.Following = append(kn.Edges.Following, n) } } } diff --git a/examples/o2m2types/ent/ent.go b/examples/o2m2types/ent/ent.go index 01c2029ecd..7e2141a167 100644 --- a/examples/o2m2types/ent/ent.go +++ b/examples/o2m2types/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/o2m2types/ent/pet" "entgo.io/ent/examples/o2m2types/ent/user" ) @@ -466,3 +467,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/o2m2types/ent/pet_query.go b/examples/o2m2types/ent/pet_query.go index c976c30d3a..41365c98db 100644 --- a/examples/o2m2types/ent/pet_query.go +++ b/examples/o2m2types/ent/pet_query.go @@ -355,7 +355,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} withFKs = pq.withFKs @@ -371,18 +371,17 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } diff --git a/examples/o2m2types/ent/user_query.go b/examples/o2m2types/ent/user_query.go index 614c168a7d..5205328a81 100644 --- a/examples/o2m2types/ent/user_query.go +++ b/examples/o2m2types/ent/user_query.go @@ -355,7 +355,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -364,18 +364,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/examples/o2mrecur/ent/ent.go b/examples/o2mrecur/ent/ent.go index 05c3b2e243..5e35bb76d2 100644 --- a/examples/o2mrecur/ent/ent.go +++ b/examples/o2mrecur/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/o2mrecur/ent/node" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/o2mrecur/ent/node_query.go b/examples/o2mrecur/ent/node_query.go index 34f817e0fd..eb7a12a59a 100644 --- a/examples/o2mrecur/ent/node_query.go +++ b/examples/o2mrecur/ent/node_query.go @@ -390,7 +390,7 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { return nil } -func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { +func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, error) { var ( nodes = []*Node{} withFKs = nq.withFKs @@ -407,18 +407,17 @@ func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { _spec.Node.Columns = append(_spec.Node.Columns, node.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Node{config: nq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Node).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Node{config: nq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, nq.driver, _spec); err != nil { return nil, err } diff --git a/examples/o2o2types/ent/card_query.go b/examples/o2o2types/ent/card_query.go index 06898275cc..f35b6c81ce 100644 --- a/examples/o2o2types/ent/card_query.go +++ b/examples/o2o2types/ent/card_query.go @@ -355,7 +355,7 @@ func (cq *CardQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { +func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, error) { var ( nodes = []*Card{} withFKs = cq.withFKs @@ -371,18 +371,17 @@ func (cq *CardQuery) sqlAll(ctx context.Context) ([]*Card, error) { _spec.Node.Columns = append(_spec.Node.Columns, card.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Card{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Card).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Card{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/examples/o2o2types/ent/ent.go b/examples/o2o2types/ent/ent.go index eabaab71c6..e683344996 100644 --- a/examples/o2o2types/ent/ent.go +++ b/examples/o2o2types/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/o2o2types/ent/card" "entgo.io/ent/examples/o2o2types/ent/user" ) @@ -466,3 +467,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/o2o2types/ent/user_query.go b/examples/o2o2types/ent/user_query.go index 3acc0c55e8..c0a9907d16 100644 --- a/examples/o2o2types/ent/user_query.go +++ b/examples/o2o2types/ent/user_query.go @@ -355,7 +355,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -364,18 +364,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/examples/o2obidi/ent/ent.go b/examples/o2obidi/ent/ent.go index 0896127e55..02929c19b6 100644 --- a/examples/o2obidi/ent/ent.go +++ b/examples/o2obidi/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/o2obidi/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/o2obidi/ent/user_query.go b/examples/o2obidi/ent/user_query.go index 50deaed5bc..5a2d8745d8 100644 --- a/examples/o2obidi/ent/user_query.go +++ b/examples/o2obidi/ent/user_query.go @@ -354,7 +354,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -370,18 +370,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/examples/o2orecur/ent/ent.go b/examples/o2orecur/ent/ent.go index ddc3ba9701..f84fdc42d9 100644 --- a/examples/o2orecur/ent/ent.go +++ b/examples/o2orecur/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/o2orecur/ent/node" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/o2orecur/ent/node_query.go b/examples/o2orecur/ent/node_query.go index af1014f59b..7f424628c9 100644 --- a/examples/o2orecur/ent/node_query.go +++ b/examples/o2orecur/ent/node_query.go @@ -390,7 +390,7 @@ func (nq *NodeQuery) prepareQuery(ctx context.Context) error { return nil } -func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { +func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, error) { var ( nodes = []*Node{} withFKs = nq.withFKs @@ -407,18 +407,17 @@ func (nq *NodeQuery) sqlAll(ctx context.Context) ([]*Node, error) { _spec.Node.Columns = append(_spec.Node.Columns, node.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Node{config: nq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Node).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Node{config: nq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, nq.driver, _spec); err != nil { return nil, err } diff --git a/examples/privacyadmin/ent/ent.go b/examples/privacyadmin/ent/ent.go index 56b596f41c..7e3b9f98d6 100644 --- a/examples/privacyadmin/ent/ent.go +++ b/examples/privacyadmin/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/privacyadmin/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/privacyadmin/ent/user_query.go b/examples/privacyadmin/ent/user_query.go index 0e3ad6882d..4a6b1675de 100644 --- a/examples/privacyadmin/ent/user_query.go +++ b/examples/privacyadmin/ent/user_query.go @@ -324,23 +324,22 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } diff --git a/examples/privacytenant/ent/ent.go b/examples/privacytenant/ent/ent.go index 00de15c5e9..0c03efb1c1 100644 --- a/examples/privacytenant/ent/ent.go +++ b/examples/privacytenant/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/privacytenant/ent/group" "entgo.io/ent/examples/privacytenant/ent/tenant" "entgo.io/ent/examples/privacytenant/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/privacytenant/ent/group_query.go b/examples/privacytenant/ent/group_query.go index a58f43e4c3..0bdc37464a 100644 --- a/examples/privacytenant/ent/group_query.go +++ b/examples/privacytenant/ent/group_query.go @@ -399,7 +399,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} withFKs = gq.withFKs @@ -416,18 +416,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { _spec.Node.Columns = append(_spec.Node.Columns, group.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -465,66 +464,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/examples/privacytenant/ent/tenant_query.go b/examples/privacytenant/ent/tenant_query.go index b7e21c5495..5beaa896e0 100644 --- a/examples/privacytenant/ent/tenant_query.go +++ b/examples/privacytenant/ent/tenant_query.go @@ -324,23 +324,22 @@ func (tq *TenantQuery) prepareQuery(ctx context.Context) error { return nil } -func (tq *TenantQuery) sqlAll(ctx context.Context) ([]*Tenant, error) { +func (tq *TenantQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Tenant, error) { var ( nodes = []*Tenant{} _spec = tq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Tenant{config: tq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Tenant).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Tenant{config: tq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, tq.driver, _spec); err != nil { return nil, err } diff --git a/examples/privacytenant/ent/user_query.go b/examples/privacytenant/ent/user_query.go index d98cc6b7d3..6a5e923ed0 100644 --- a/examples/privacytenant/ent/user_query.go +++ b/examples/privacytenant/ent/user_query.go @@ -399,7 +399,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} withFKs = uq.withFKs @@ -416,18 +416,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -465,66 +464,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } diff --git a/examples/start/ent/car_query.go b/examples/start/ent/car_query.go index 4b986e882e..7187a5ddd5 100644 --- a/examples/start/ent/car_query.go +++ b/examples/start/ent/car_query.go @@ -355,7 +355,7 @@ func (cq *CarQuery) prepareQuery(ctx context.Context) error { return nil } -func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { +func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, error) { var ( nodes = []*Car{} withFKs = cq.withFKs @@ -371,18 +371,17 @@ func (cq *CarQuery) sqlAll(ctx context.Context) ([]*Car, error) { _spec.Node.Columns = append(_spec.Node.Columns, car.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Car{config: cq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Car).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Car{config: cq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, cq.driver, _spec); err != nil { return nil, err } diff --git a/examples/start/ent/ent.go b/examples/start/ent/ent.go index c7317b1425..c848bc4599 100644 --- a/examples/start/ent/ent.go +++ b/examples/start/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/start/ent/car" "entgo.io/ent/examples/start/ent/group" "entgo.io/ent/examples/start/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/start/ent/group_query.go b/examples/start/ent/group_query.go index 04f18cbb32..68385c068f 100644 --- a/examples/start/ent/group_query.go +++ b/examples/start/ent/group_query.go @@ -355,7 +355,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} _spec = gq.querySpec() @@ -364,18 +364,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -384,66 +383,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/examples/start/ent/user_query.go b/examples/start/ent/user_query.go index 9561ece40e..740372611e 100644 --- a/examples/start/ent/user_query.go +++ b/examples/start/ent/user_query.go @@ -391,7 +391,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -401,18 +401,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -450,66 +449,54 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } diff --git a/examples/traversal/ent/ent.go b/examples/traversal/ent/ent.go index a5163f88f0..5763946e37 100644 --- a/examples/traversal/ent/ent.go +++ b/examples/traversal/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/traversal/ent/group" "entgo.io/ent/examples/traversal/ent/pet" "entgo.io/ent/examples/traversal/ent/user" @@ -468,3 +469,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/traversal/ent/group_query.go b/examples/traversal/ent/group_query.go index b36c6ebd3b..8e33c0d68a 100644 --- a/examples/traversal/ent/group_query.go +++ b/examples/traversal/ent/group_query.go @@ -391,7 +391,7 @@ func (gq *GroupQuery) prepareQuery(ctx context.Context) error { return nil } -func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { +func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { var ( nodes = []*Group{} withFKs = gq.withFKs @@ -408,18 +408,17 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { _spec.Node.Columns = append(_spec.Node.Columns, group.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Group{config: gq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Group).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Group{config: gq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, gq.driver, _spec); err != nil { return nil, err } @@ -428,66 +427,54 @@ func (gq *GroupQuery) sqlAll(ctx context.Context) ([]*Group, error) { } if query := gq.withUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Group, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Users = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*Group) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: group.UsersTable, - Columns: group.UsersPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(group.UsersPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, gq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "users": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Users = append(nodes[i].Edges.Users, n) + for kn := range nodes { + kn.Edges.Users = append(kn.Edges.Users, n) } } } diff --git a/examples/traversal/ent/pet_query.go b/examples/traversal/ent/pet_query.go index 21e28f2cb2..d55abfb5ad 100644 --- a/examples/traversal/ent/pet_query.go +++ b/examples/traversal/ent/pet_query.go @@ -391,7 +391,7 @@ func (pq *PetQuery) prepareQuery(ctx context.Context) error { return nil } -func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { +func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, error) { var ( nodes = []*Pet{} withFKs = pq.withFKs @@ -408,18 +408,17 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { _spec.Node.Columns = append(_spec.Node.Columns, pet.ForeignKeys...) } _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &Pet{config: pq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*Pet).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &Pet{config: pq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, pq.driver, _spec); err != nil { return nil, err } @@ -428,66 +427,54 @@ func (pq *PetQuery) sqlAll(ctx context.Context) ([]*Pet, error) { } if query := pq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*Pet, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*Pet) + nids := make(map[int]map[*Pet]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*Pet{} } - var ( - edgeids []int - edges = make(map[int][]*Pet) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: pet.FriendsTable, - Columns: pet.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(pet.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(pet.FriendsTable) + s.Join(joinT).On(s.C(pet.FieldID), joinT.C(pet.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(pet.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(pet.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Pet]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, pq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(pet.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } diff --git a/examples/traversal/ent/user_query.go b/examples/traversal/ent/user_query.go index b6656a153e..18999c1159 100644 --- a/examples/traversal/ent/user_query.go +++ b/examples/traversal/ent/user_query.go @@ -461,7 +461,7 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() @@ -473,18 +473,17 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err } @@ -522,131 +521,107 @@ func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { } if query := uq.withFriends; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Friends = []*User{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: false, - Table: user.FriendsTable, - Columns: user.FriendsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendsPrimaryKey[0], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") - } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "friends": %w`, err) - } - query.Where(user.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Friends = append(nodes[i].Edges.Friends, n) + for kn := range nodes { + kn.Edges.Friends = append(kn.Edges.Friends, n) } } } if query := uq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - ids := make(map[int]*User, len(nodes)) - for _, node := range nodes { - ids[node.ID] = node - fks = append(fks, node.ID) + edgeids := make([]driver.Value, len(nodes)) + byid := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeids[i] = node.ID + byid[node.ID] = node node.Edges.Groups = []*Group{} } - var ( - edgeids []int - edges = make(map[int][]*User) - ) - _spec := &sqlgraph.EdgeQuerySpec{ - Edge: &sqlgraph.EdgeSpec{ - Inverse: true, - Table: user.GroupsTable, - Columns: user.GroupsPrimaryKey, - }, - Predicate: func(s *sql.Selector) { - s.Where(sql.InValues(user.GroupsPrimaryKey[1], fks...)) - }, - ScanValues: func() [2]interface{} { - return [2]interface{}{new(sql.NullInt64), new(sql.NullInt64)} - }, - Assign: func(out, in interface{}) error { - eout, ok := out.(*sql.NullInt64) - if !ok || eout == nil { - return fmt.Errorf("unexpected id value for edge-out") + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - ein, ok := in.(*sql.NullInt64) - if !ok || ein == nil { - return fmt.Errorf("unexpected id value for edge-in") - } - outValue := int(eout.Int64) - inValue := int(ein.Int64) - node, ok := ids[outValue] - if !ok { - return fmt.Errorf("unexpected node id in edges: %v", outValue) - } - if _, ok := edges[inValue]; !ok { - edgeids = append(edgeids, inValue) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - edges[inValue] = append(edges[inValue], node) + nids[inValue][byid[outValue]] = struct{}{} return nil - }, - } - if err := sqlgraph.QueryEdges(ctx, uq.driver, _spec); err != nil { - return nil, fmt.Errorf(`query edges "groups": %w`, err) - } - query.Where(group.IDIn(edgeids...)) - neighbors, err := query.All(ctx) + } + }) if err != nil { return nil, err } for _, n := range neighbors { - nodes, ok := edges[n.ID] + nodes, ok := nids[n.ID] if !ok { return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for i := range nodes { - nodes[i].Edges.Groups = append(nodes[i].Edges.Groups, n) + for kn := range nodes { + kn.Edges.Groups = append(kn.Edges.Groups, n) } } } diff --git a/examples/version/ent/ent.go b/examples/version/ent/ent.go index 57fe4b8c68..3f549a3baa 100644 --- a/examples/version/ent/ent.go +++ b/examples/version/ent/ent.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/examples/version/ent/user" ) @@ -464,3 +465,6 @@ func (s *selector) BoolX(ctx context.Context) bool { } return v } + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/examples/version/ent/user_query.go b/examples/version/ent/user_query.go index 619bec3628..66008ac01c 100644 --- a/examples/version/ent/user_query.go +++ b/examples/version/ent/user_query.go @@ -317,23 +317,22 @@ func (uq *UserQuery) prepareQuery(ctx context.Context) error { return nil } -func (uq *UserQuery) sqlAll(ctx context.Context) ([]*User, error) { +func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, error) { var ( nodes = []*User{} _spec = uq.querySpec() ) _spec.ScanValues = func(columns []string) ([]interface{}, error) { - node := &User{config: uq.config} - nodes = append(nodes, node) - return node.scanValues(columns) + return (*User).scanValues(nil, columns) } _spec.Assign = func(columns []string, values []interface{}) error { - if len(nodes) == 0 { - return fmt.Errorf("ent: Assign called without calling ScanValues") - } - node := nodes[len(nodes)-1] + node := &User{config: uq.config} + nodes = append(nodes, node) return node.assignValues(columns, values) } + for i := range hooks { + hooks[i](ctx, _spec) + } if err := sqlgraph.QueryNodes(ctx, uq.driver, _spec); err != nil { return nil, err }