Skip to content

Commit

Permalink
entc/gen: use join for loading m2m relationship (ent#2417)
Browse files Browse the repository at this point in the history
* entc/gen: use join for m2m relationship

* entc/gen: add test for eager-load inverse-m2m
  • Loading branch information
a8m authored and gitlawr committed Apr 13, 2022
1 parent 5961136 commit 06f1223
Show file tree
Hide file tree
Showing 125 changed files with 2,093 additions and 2,483 deletions.
6 changes: 5 additions & 1 deletion entc/gen/template/dialect/sql/globals.tmpl
Expand Up @@ -7,4 +7,8 @@ in the LICENSE file in the root directory of this source tree.
{{/* gotype: entgo.io/ent/entc/gen.Graph */}}

{{/* custom globals and helpers for sql dialects */}}
{{ define "dialect/sql/globals" }}{{ end }}
{{ define "dialect/sql/globals" }}

// queryHook describes an internal hook for the different sqlAll methods.
type queryHook func(context.Context, *sqlgraph.QuerySpec)
{{ end }}
111 changes: 47 additions & 64 deletions entc/gen/template/dialect/sql/query.tmpl
Expand Up @@ -23,7 +23,7 @@ in the LICENSE file in the root directory of this source tree.
{{ $builder := pascal $.Scope.Builder }}
{{ $receiver := receiver $builder }}

func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name }}, error) {
func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...queryHook) ([]*{{ $.Name }}, error) {
var (
nodes = []*{{ $.Name }}{}
{{- with $.UnexportedForeignKeys }}
Expand All @@ -49,15 +49,11 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name
}
{{- end }}
_spec.ScanValues = func(columns []string) ([]interface{}, error) {
node := &{{ $.Name }}{config: {{ $receiver }}.config}
nodes = append(nodes, node)
return node.scanValues(columns)
return (*{{ $.Name }}).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []interface{}) error {
if len(nodes) == 0 {
return fmt.Errorf("{{ $pkg }}: Assign called without calling ScanValues")
}
node := nodes[len(nodes)-1]
node := &{{ $.Name }}{config: {{ $receiver }}.config}
nodes = append(nodes, node)
{{- with $.Edges }}
node.Edges.loadedTypes = loadedTypes
{{- end }}
Expand All @@ -69,6 +65,9 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context) ([]*{{ $.Name
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, {{ $receiver }}.driver, _spec); err != nil {
return nil, err
}
Expand Down Expand Up @@ -272,74 +271,58 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $receiver := $.Scope.Rec }}
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
{{- if $e.M2M }}
fks := make([]driver.Value, 0, len(nodes))
ids := make(map[{{ $.ID.Type }}]*{{ $.Name }}, len(nodes))
for _, node := range nodes {
ids[node.ID] = node
fks = append(fks, node.ID)
edgeids := make([]driver.Value, len(nodes))
byid := make(map[{{ $.ID.Type }}]*{{ $.Name }})
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
for i, node := range nodes {
edgeids[i] = node.ID
byid[node.ID] = node
node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
}
var (
edgeids []{{ $e.Type.ID.Type }}
edges = make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
)
_spec := &sqlgraph.EdgeQuerySpec{
Edge: &sqlgraph.EdgeSpec{
Inverse: {{ $e.IsInverse }},
Table: {{ $.Package }}.{{ $e.TableConstant }},
Columns: {{ $.Package }}.{{ $e.PKConstant }},
},
Predicate: func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.PKConstant }}[{{ if $e.IsInverse }}1{{ else }}0{{ end }}], fks...))
},
query.Where(func(s *sql.Selector) {
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeids...))
columns := s.SelectedColumns()
s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
ScanValues: func() [2]interface{}{
return [2]interface{}{new({{ $out }}), new({{ $in }})}
},
Assign: func(out, in interface{}) error {
eout, ok := out.(*{{ $out }})
if !ok || eout == nil {
return fmt.Errorf("unexpected id value for edge-out")
}
ein, ok := in.(*{{ $in }})
if !ok || ein == nil {
return fmt.Errorf("unexpected id value for edge-in")
}
outValue := {{ with extend $ "Arg" "eout" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "ein" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
node, ok := ids[outValue]
if !ok {
return fmt.Errorf("unexpected node id in edges: %v", outValue)
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
if _, ok := edges[inValue]; !ok {
edgeids = append(edgeids, inValue)
return append([]interface{}{new({{ $out }})}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
if nids[inValue] == nil {
nids[inValue] = map[*{{ $.Name }}]struct{}{byid[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
edges[inValue] = append(edges[inValue], node)
nids[inValue][byid[outValue]] = struct{}{}
return nil
},
}
{{- /* Allow mutating the sqlgraph.EdgeQuerySpec by ent extensions or user templates.*/}}
{{- with $tmpls := matchTemplate "dialect/sql/query/eagerloading/spec/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}
if err := sqlgraph.QueryEdges(ctx, {{ $receiver }}.driver, _spec); err != nil {
return nil, fmt.Errorf(`query edges "{{ $e.Name }}": %w`, err)
}
query.Where({{ $e.Type.Package }}.IDIn(edgeids...))
neighbors, err := query.All(ctx)
}
})
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := edges[n.ID]
nodes, ok := nids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.{{ $e.StructField }} = append(nodes[i].Edges.{{ $e.StructField }}, n)
for kn := range nodes {
kn.Edges.{{ $e.StructField }} = append(kn.Edges.{{ $e.StructField }}, n)
}
}
{{- else if $e.OwnFK }}
Expand Down Expand Up @@ -415,9 +398,9 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{- $field := $.Scope.Field }}
{{- $scantype := $.Scope.ScanType }}
{{- if hasPrefix $scantype "sql" -}}
{{ $field.ScanTypeField $arg -}}
{{ printf "%s.(*%s)" $arg $scantype | $field.ScanTypeField -}}
{{- else -}}
{{ if not $field.Nillable }}*{{ end }}{{ $arg }}
{{ if not $field.Nillable }}*{{ end }}{{ printf "%s.(*%s)" $arg $scantype }}
{{- end }}
{{- end }}

Expand Down
15 changes: 7 additions & 8 deletions entc/integration/cascadelete/ent/comment_query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions entc/integration/cascadelete/ent/ent.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions entc/integration/cascadelete/ent/post_query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions entc/integration/cascadelete/ent/user_query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions entc/integration/config/ent/ent.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions entc/integration/config/ent/user_query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions entc/integration/customid/ent/account_query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 06f1223

Please sign in to comment.