Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

entc/gen: move select and group builders' scan functions to shared struct #2412

Merged
merged 1 commit into from Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 66 additions & 0 deletions entc/gen/template/base.tmpl
Expand Up @@ -190,6 +190,72 @@ func IsConstraintError(err error) bool {
return errors.As(err, &e)
}


// selector embedded by the different Select/GroupBy builders.
type selector struct {
label string
flds *[]string
scan func (context.Context, interface{}) error
}

// ScanX is like Scan, but panics if an error occurs.
func (s *selector) ScanX(ctx context.Context, v interface{}) {
if err := s.scan(ctx, v); err != nil {
panic(err)
}
}

{{ range $t := primitives }}
{{ $plural := pascal $t | plural }}
// {{ $plural }} returns list of {{ plural $t }} from a selector. It is only allowed when selecting one field.
func (s *selector) {{ $plural }}(ctx context.Context) ([]{{ $t }}, error) {
if len(*s.flds) > 1 {
return nil, errors.New("{{ $pkg }}: {{ $plural }} is not achievable when selecting more than 1 field")
}
var v []{{ $t }}
if err := s.scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}

// {{ $plural }}X is like {{ $plural }}, but panics if an error occurs.
func (s *selector) {{ $plural }}X(ctx context.Context) []{{ $t }} {
v, err := s.{{ $plural }}(ctx)
if err != nil {
panic(err)
}
return v
}

{{ $singular := pascal $t -}}
// {{ $singular }} returns a single {{ $t }} from a selector. It is only allowed when selecting one field.
func (s *selector) {{ $singular }}(ctx context.Context) (_ {{ $t }}, err error) {
var v []{{ $t }}
if v, err = s.{{ $plural }}(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{s.label}
default:
err = fmt.Errorf("{{ $pkg }}: {{ $plural }} returned %d results when one was expected", len(v))
}
return
}

// {{ $singular }}X is like {{ $singular }}, but panics if an error occurs.
func (s *selector) {{ $singular }}X(ctx context.Context) {{ $t }} {
v, err := s.{{ $singular }}(ctx)
if err != nil {
panic(err)
}
return v
}
{{ end }}

{{/* expand error types and global helpers. */}}
{{ $tmpl = printf "dialect/%s/errors" $.Storage }}
{{ if hasTemplate $tmpl }}
Expand Down
136 changes: 12 additions & 124 deletions entc/gen/template/builder/query.tmpl
Expand Up @@ -324,15 +324,17 @@ func ({{ $receiver }} *{{ $builder }}) Clone() *{{ $builder }} {
//
{{- end }}
func ({{ $receiver }} *{{ $builder }}) GroupBy(field string, fields ...string) *{{ $groupBuilder }} {
group := &{{ $groupBuilder }}{config: {{ $receiver }}.config}
group.fields = append([]string{field}, fields...)
group.path = func(ctx context.Context) (prev {{ $.Storage.Builder }}, err error) {
grbuild := &{{ $groupBuilder }}{config: {{ $receiver }}.config}
grbuild.fields = append([]string{field}, fields...)
grbuild.path = func(ctx context.Context) (prev {{ $.Storage.Builder }}, err error) {
if err := {{ $receiver }}.prepareQuery(ctx); err != nil {
return nil, err
}
return {{ $receiver }}.{{ $.Storage }}Query(ctx), nil
}
return group
grbuild.label = {{ $.Package }}.Label
grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
return grbuild
}

{{ $selectBuilder := pascal $.Name | printf "%sSelect" }}
Expand All @@ -355,7 +357,10 @@ func ({{ $receiver }} *{{ $builder }}) GroupBy(field string, fields ...string) *
{{- end }}
func ({{ $receiver }} *{{ $builder }}) Select(fields ...string) *{{ $selectBuilder }} {
{{ $receiver }}.fields = append({{ $receiver }}.fields, fields...)
return &{{ $selectBuilder }}{ {{ $builder }}: {{ $receiver }} }
selbuild := &{{ $selectBuilder }}{ {{ $builder }}: {{ $receiver }} }
selbuild.label = {{ $.Package }}.Label
selbuild.flds, selbuild.scan = &{{ $receiver }}.fields, selbuild.Scan
return selbuild
}

func ({{ $receiver }} *{{ $builder }}) prepareQuery(ctx context.Context) error {
Expand Down Expand Up @@ -405,6 +410,7 @@ func ({{ $receiver }} *{{ $builder }}) prepareQuery(ctx context.Context) error {
// {{ $groupBuilder }} is the group-by builder for {{ $.Name }} entities.
type {{ $groupBuilder }} struct {
config
selector
fields []string
fns []AggregateFunc
// intermediate query (i.e. traversal path).
Expand All @@ -428,66 +434,6 @@ func ({{ $groupReceiver }} *{{ $groupBuilder }}) Scan(ctx context.Context, v int
return {{ $groupReceiver }}.{{ $.Storage }}Scan(ctx, v)
}

// ScanX is like Scan, but panics if an error occurs.
func ({{ $groupReceiver }} *{{ $groupBuilder }}) ScanX(ctx context.Context, v interface{}) {
if err := {{ $groupReceiver }}.Scan(ctx, v); err != nil {
panic(err)
}
}

{{ range $t := primitives }}
{{ $plural := pascal $t | plural }}
// {{ $plural }} returns list of {{ plural $t }} from group-by.
// It is only allowed when executing a group-by query with one field.
func ({{ $groupReceiver }} *{{ $groupBuilder }}) {{ $plural }}(ctx context.Context) ([]{{ $t }}, error) {
if len({{ $groupReceiver }}.fields) > 1 {
return nil, errors.New("{{ $pkg }}: {{ $groupBuilder }}.{{ $plural }} is not achievable when grouping more than 1 field")
}
var v []{{ $t }}
if err := {{ $groupReceiver }}.Scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}

// {{ $plural }}X is like {{ $plural }}, but panics if an error occurs.
func ({{ $groupReceiver }} *{{ $groupBuilder }}) {{ $plural }}X(ctx context.Context) []{{ $t }} {
v, err := {{ $groupReceiver }}.{{ $plural }}(ctx)
if err != nil {
panic(err)
}
return v
}

{{ $singular := pascal $t -}}
// {{ $singular }} returns a single {{ $t }} from a group-by query.
// It is only allowed when executing a group-by query with one field.
func ({{ $groupReceiver }} *{{ $groupBuilder }}) {{ $singular }}(ctx context.Context) (_ {{ $t }}, err error) {
var v []{{ $t }}
if v, err = {{ $groupReceiver }}.{{ $plural }}(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{ {{ $.Package }}.Label}
default:
err = fmt.Errorf("{{ $pkg }}: {{ $groupBuilder }}.{{ $plural }} returned %d results when one was expected", len(v))
}
return
}

// {{ $singular }}X is like {{ $singular }}, but panics if an error occurs.
func ({{ $groupReceiver }} *{{ $groupBuilder }}) {{ $singular }}X(ctx context.Context) {{ $t }} {
v, err := {{ $groupReceiver }}.{{ $singular }}(ctx)
if err != nil {
panic(err)
}
return v
}
{{ end }}

{{ with extend $ "Builder" $groupBuilder }}
{{ $tmpl := printf "dialect/%s/group" $.Storage }}
{{ xtemplate $tmpl . }}
Expand All @@ -500,11 +446,11 @@ func ({{ $groupReceiver }} *{{ $groupBuilder }}) ScanX(ctx context.Context, v in
// {{ $selectBuilder }} is the builder for selecting fields of {{ pascal $.Name }} entities.
type {{ $selectBuilder }} struct {
*{{ $builder }}
selector
// intermediate query (i.e. traversal path).
{{ $.Storage }} {{ $.Storage.Builder }}
}


// Scan applies the selector query and scans the result into the given value.
func ({{ $selectReceiver }} *{{ $selectBuilder }}) Scan(ctx context.Context, v interface{}) error {
if err := {{ $selectReceiver }}.prepareQuery(ctx); err != nil {
Expand All @@ -514,64 +460,6 @@ func ({{ $selectReceiver }} *{{ $selectBuilder }}) Scan(ctx context.Context, v i
return {{ $selectReceiver }}.{{ $.Storage }}Scan(ctx, v)
}

// ScanX is like Scan, but panics if an error occurs.
func ({{ $selectReceiver }} *{{ $selectBuilder }}) ScanX(ctx context.Context, v interface{}) {
if err := {{ $selectReceiver }}.Scan(ctx, v); err != nil {
panic(err)
}
}

{{ range $t := primitives }}
{{ $plural := pascal $t | plural }}
// {{ $plural }} returns list of {{ plural $t }} from a selector. It is only allowed when selecting one field.
func ({{ $selectReceiver }} *{{ $selectBuilder }}) {{ $plural }}(ctx context.Context) ([]{{ $t }}, error) {
if len({{ $selectReceiver }}.fields) > 1 {
return nil, errors.New("{{ $pkg }}: {{ $selectBuilder }}.{{ $plural }} is not achievable when selecting more than 1 field")
}
var v []{{ $t }}
if err := {{ $selectReceiver }}.Scan(ctx, &v); err != nil {
return nil, err
}
return v, nil
}

// {{ $plural }}X is like {{ $plural }}, but panics if an error occurs.
func ({{ $selectReceiver }} *{{ $selectBuilder }}) {{ $plural }}X(ctx context.Context) []{{ $t }} {
v, err := {{ $selectReceiver }}.{{ $plural }}(ctx)
if err != nil {
panic(err)
}
return v
}

{{ $singular := pascal $t -}}
// {{ $singular }} returns a single {{ $t }} from a selector. It is only allowed when selecting one field.
func ({{ $selectReceiver }} *{{ $selectBuilder }}) {{ $singular }}(ctx context.Context) (_ {{ $t }}, err error) {
var v []{{ $t }}
if v, err = {{ $selectReceiver }}.{{ $plural }}(ctx); err != nil {
return
}
switch len(v) {
case 1:
return v[0], nil
case 0:
err = &NotFoundError{ {{ $.Package }}.Label}
default:
err = fmt.Errorf("{{ $pkg }}: {{ $selectBuilder }}.{{ $plural }} returned %d results when one was expected", len(v))
}
return
}

// {{ $singular }}X is like {{ $singular }}, but panics if an error occurs.
func ({{ $selectReceiver }} *{{ $selectBuilder }}) {{ $singular }}X(ctx context.Context) {{ $t }} {
v, err := {{ $selectReceiver }}.{{ $singular }}(ctx)
if err != nil {
panic(err)
}
return v
}
{{ end }}

{{ with extend $ "Builder" $selectBuilder }}
{{ $tmpl := printf "dialect/%s/select" $.Storage }}
{{ xtemplate $tmpl . }}
Expand Down