Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

entc/gen: add the sql/execquery feature flag #2447

Merged
merged 1 commit into from Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
Expand Up @@ -14,10 +14,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v2
with:
go-version: 1.17
- name: Run linters
uses: golangci/golangci-lint-action@v2.5.2
uses: golangci/golangci-lint-action@v3.1.0
with:
version: v1.44.0
version: v1.45.2

unit:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .golangci.yml
@@ -1,4 +1,5 @@
run:
go: '1.17'
timeout: 3m

linters-settings:
Expand Down
50 changes: 49 additions & 1 deletion dialect/dialect.go
Expand Up @@ -93,12 +93,36 @@ func (d *DebugDriver) Exec(ctx context.Context, query string, args, v interface{
return d.Driver.Exec(ctx, query, args, v)
}

// ExecContext logs its params and calls the underlying driver ExecContext method if it is supported.
func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
drv, ok := d.Driver.(interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Driver.ExecContext is not supported")
}
d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args))
return drv.ExecContext(ctx, query, args...)
}

// Query logs its params and calls the underlying driver Query method.
func (d *DebugDriver) Query(ctx context.Context, query string, args, v interface{}) error {
d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v", query, args))
return d.Driver.Query(ctx, query, args, v)
}

// QueryContext logs its params and calls the underlying driver QueryContext method if it is supported.
func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
drv, ok := d.Driver.(interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Driver.QueryContext is not supported")
}
d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args))
return drv.QueryContext(ctx, query, args...)
}

// Tx adds an log-id for the transaction and calls the underlying driver Tx command.
func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) {
tx, err := d.Driver.Tx(ctx)
Expand All @@ -110,7 +134,7 @@ func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) {
return &DebugTx{tx, id, d.log, ctx}, nil
}

// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it's supported.
// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it is supported.
func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
drv, ok := d.Driver.(interface {
BeginTx(context.Context, *sql.TxOptions) (Tx, error)
Expand Down Expand Up @@ -141,12 +165,36 @@ func (d *DebugTx) Exec(ctx context.Context, query string, args, v interface{}) e
return d.Tx.Exec(ctx, query, args, v)
}

// ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported.
func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
drv, ok := d.Tx.(interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Tx.ExecContext is not supported")
}
d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args))
return drv.ExecContext(ctx, query, args...)
}

// Query logs its params and calls the underlying transaction Query method.
func (d *DebugTx) Query(ctx context.Context, query string, args, v interface{}) error {
d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v", d.id, query, args))
return d.Tx.Query(ctx, query, args, v)
}

// QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported.
func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
drv, ok := d.Tx.(interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Tx.QueryContext is not supported")
}
d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args))
return drv.QueryContext(ctx, query, args...)
}

// Commit logs this step and calls the underlying transaction Commit method.
func (d *DebugTx) Commit() error {
d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id))
Expand Down
6 changes: 3 additions & 3 deletions dialect/sql/driver.go
Expand Up @@ -62,8 +62,8 @@ func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, erro
return nil, err
}
return &Tx{
ExecQuerier: Conn{tx},
Tx: tx,
Conn: Conn{tx},
Tx: tx,
}, nil
}

Expand All @@ -72,7 +72,7 @@ func (d *Driver) Close() error { return d.DB().Close() }

// Tx implements dialect.Tx interface.
type Tx struct {
dialect.ExecQuerier
Conn
driver.Tx
}

Expand Down
9 changes: 9 additions & 0 deletions entc/gen/feature.go
Expand Up @@ -84,6 +84,14 @@ var (
Description: "Allows users to attach custom modifiers to queries",
}

// FeatureExecQuery provides a feature-flag for exposing the ExecContext/QueryContext methods of the underlying SQL drivers.
FeatureExecQuery = Feature{
Name: "sql/execquery",
Stage: Experimental,
Default: false,
Description: "Allows users to execute statements using the ExecContext/QueryContext methods of the underlying driver",
}

// FeatureUpsert provides a feature-flag for adding upsert (ON CONFLICT) capabilities to create builders.
FeatureUpsert = Feature{
Name: "sql/upsert",
Expand All @@ -107,6 +115,7 @@ var (
FeatureSchemaConfig,
FeatureLock,
FeatureModifier,
FeatureExecQuery,
FeatureUpsert,
FeatureVersionedMigration,
}
Expand Down
2 changes: 2 additions & 0 deletions entc/gen/template.go
Expand Up @@ -195,6 +195,8 @@ var (
"model/additional/*",
"model/comment/additional/*",
"model/edges/fields/additional/*",
"tx/additional/*",
"tx/additional/*/*",
"update/additional/*",
"query/additional/*",
}
Expand Down
6 changes: 6 additions & 0 deletions entc/gen/template/config.tmpl
Expand Up @@ -108,4 +108,10 @@ func Driver(driver dialect.Driver) Option {
{{- end }}
{{- end }}

{{- with $tmpls := matchTemplate "config/additional/*" "config/additional/*/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}

{{ end }}
71 changes: 71 additions & 0 deletions entc/gen/template/dialect/sql/feature/execquery.tmpl
@@ -0,0 +1,71 @@
{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.Graph*/}}

{{- define "import/additional/stdsql" -}}
{{- if $.FeatureEnabled "sql/execquery" }}
stdsql "database/sql"
{{- end }}
{{- end -}}

{{/* Template for adding "ExecContext"/"QueryContext" methods to the config. */}}
{{ define "config/additional/sql/execquery" }}
{{- if $.FeatureEnabled "sql/execquery" }}
// ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it.
// See, database/sql#DB.ExecContext for more information.
func (c *config) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) {
ex, ok := c.driver.(interface {
ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Driver.ExecContext is not supported")
}
return ex.ExecContext(ctx, query, args...)
}

// QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it.
// See, database/sql#DB.QueryContext for more information.
func (c *config) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) {
q, ok := c.driver.(interface {
QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Driver.QueryContext is not supported")
}
return q.QueryContext(ctx, query, args...)
}
{{- end }}
{{ end }}

{{/* Template for adding "ExecContext"/"QueryContext" methods to the client. */}}
{{ define "tx/additional/sql/execquery" }}
{{- if $.FeatureEnabled "sql/execquery" }}
// ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it.
// See, database/sql#Tx.ExecContext for more information.
func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) {
ex, ok := tx.tx.(interface {
ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Tx.ExecContext is not supported")
}
return ex.ExecContext(ctx, query, args...)
}

// QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it.
// See, database/sql#Tx.QueryContext for more information.
func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) {
q, ok := tx.tx.(interface {
QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Tx.QueryContext is not supported")
}
return q.QueryContext(ctx, query, args...)
}
{{- end }}
{{ end }}
6 changes: 6 additions & 0 deletions entc/gen/template/tx.tmpl
Expand Up @@ -174,4 +174,10 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{}

var _ dialect.Driver = (*txDriver)(nil)

{{- with $tmpls := matchTemplate "tx/additional/*" "tx/additional/*/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}

{{ end }}
28 changes: 28 additions & 0 deletions entc/integration/ent/config.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion entc/integration/ent/generate.go
Expand Up @@ -4,4 +4,4 @@

package ent

//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature entql,sql/modifier,sql/lock,sql/upsert --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature entql,sql/modifier,sql/lock,sql/upsert,sql/execquery --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema
26 changes: 26 additions & 0 deletions entc/integration/ent/tx.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions entc/integration/integration_test.go
Expand Up @@ -38,6 +38,7 @@ import (
"entgo.io/ent/entc/integration/ent/pet"
"entgo.io/ent/entc/integration/ent/schema"
"entgo.io/ent/entc/integration/ent/user"
"entgo.io/ent/entc/integration/privacy/ent/task"

"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
Expand Down Expand Up @@ -128,6 +129,7 @@ var (
Delete,
Upsert,
Relation,
ExecQuery,
Predicate,
AddValues,
ClearEdges,
Expand Down Expand Up @@ -680,6 +682,25 @@ func Select(t *testing.T, client *ent.Client) {
require.Equal(lab.QueryUsers().CountX(ctx), gs[1].UsersCount)
}

func ExecQuery(t *testing.T, client *ent.Client) {
require := require.New(t)
ctx := context.Background()
rows, err := client.QueryContext(ctx, "SELECT 1")
require.NoError(err)
require.True(rows.Next())
require.NoError(rows.Close())
tx, err := client.Tx(ctx)
require.NoError(err)
tx.Task.Create().ExecX(ctx)
require.Equal(1, tx.Task.Query().CountX(ctx))
rows, err = tx.QueryContext(ctx, "SELECT COUNT(*) FROM "+task.Table)
require.NoError(err)
count, err := sql.ScanInt(rows)
require.NoError(err)
require.Equal(1, count)
require.NoError(tx.Commit())
}

func Predicate(t *testing.T, client *ent.Client) {
require := require.New(t)
ctx := context.Background()
Expand Down