Skip to content

Commit

Permalink
entc/gen: add the sql/execquery feature flag
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Apr 5, 2022
1 parent 05246cb commit 8498d64
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 5 deletions.
50 changes: 49 additions & 1 deletion dialect/dialect.go
Expand Up @@ -93,12 +93,36 @@ func (d *DebugDriver) Exec(ctx context.Context, query string, args, v interface{
return d.Driver.Exec(ctx, query, args, v)
}

// ExecContext logs its params and calls the underlying driver ExecContext method if it is supported.
func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
drv, ok := d.Driver.(interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Driver.ExecContext is not supported")
}
d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args))
return drv.ExecContext(ctx, query, args...)
}

// Query logs its params and calls the underlying driver Query method.
func (d *DebugDriver) Query(ctx context.Context, query string, args, v interface{}) error {
d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v", query, args))
return d.Driver.Query(ctx, query, args, v)
}

// QueryContext logs its params and calls the underlying driver QueryContext method if it is supported.
func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
drv, ok := d.Driver.(interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Driver.QueryContext is not supported")
}
d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args))
return drv.QueryContext(ctx, query, args...)
}

// Tx adds an log-id for the transaction and calls the underlying driver Tx command.
func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) {
tx, err := d.Driver.Tx(ctx)
Expand All @@ -110,7 +134,7 @@ func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) {
return &DebugTx{tx, id, d.log, ctx}, nil
}

// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it's supported.
// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it is supported.
func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
drv, ok := d.Driver.(interface {
BeginTx(context.Context, *sql.TxOptions) (Tx, error)
Expand Down Expand Up @@ -141,12 +165,36 @@ func (d *DebugTx) Exec(ctx context.Context, query string, args, v interface{}) e
return d.Tx.Exec(ctx, query, args, v)
}

// ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported.
func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
drv, ok := d.Tx.(interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Tx.ExecContext is not supported")
}
d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args))
return drv.ExecContext(ctx, query, args...)
}

// Query logs its params and calls the underlying transaction Query method.
func (d *DebugTx) Query(ctx context.Context, query string, args, v interface{}) error {
d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v", d.id, query, args))
return d.Tx.Query(ctx, query, args, v)
}

// QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported.
func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
drv, ok := d.Tx.(interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Tx.QueryContext is not supported")
}
d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args))
return drv.QueryContext(ctx, query, args...)
}

// Commit logs this step and calls the underlying transaction Commit method.
func (d *DebugTx) Commit() error {
d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id))
Expand Down
6 changes: 3 additions & 3 deletions dialect/sql/driver.go
Expand Up @@ -62,8 +62,8 @@ func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, erro
return nil, err
}
return &Tx{
ExecQuerier: Conn{tx},
Tx: tx,
Conn: Conn{tx},
Tx: tx,
}, nil
}

Expand All @@ -72,7 +72,7 @@ func (d *Driver) Close() error { return d.DB().Close() }

// Tx implements dialect.Tx interface.
type Tx struct {
dialect.ExecQuerier
Conn
driver.Tx
}

Expand Down
9 changes: 9 additions & 0 deletions entc/gen/feature.go
Expand Up @@ -84,6 +84,14 @@ var (
Description: "Allows users to attach custom modifiers to queries",
}

// FeatureExecQuery provides a feature-flag for exposing the ExecContext/QueryContext methods of the underlying SQL drivers.
FeatureExecQuery = Feature{
Name: "sql/execquery",
Stage: Experimental,
Default: false,
Description: "Allows users to execute statements using the ExecContext/QueryContext methods of the underlying driver",
}

// FeatureUpsert provides a feature-flag for adding upsert (ON CONFLICT) capabilities to create builders.
FeatureUpsert = Feature{
Name: "sql/upsert",
Expand All @@ -107,6 +115,7 @@ var (
FeatureSchemaConfig,
FeatureLock,
FeatureModifier,
FeatureExecQuery,
FeatureUpsert,
FeatureVersionedMigration,
}
Expand Down
2 changes: 2 additions & 0 deletions entc/gen/template.go
Expand Up @@ -195,6 +195,8 @@ var (
"model/additional/*",
"model/comment/additional/*",
"model/edges/fields/additional/*",
"tx/additional/*",
"tx/additional/*/*",
"update/additional/*",
"query/additional/*",
}
Expand Down
6 changes: 6 additions & 0 deletions entc/gen/template/config.tmpl
Expand Up @@ -108,4 +108,10 @@ func Driver(driver dialect.Driver) Option {
{{- end }}
{{- end }}

{{- with $tmpls := matchTemplate "config/additional/*" "config/additional/*/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}

{{ end }}
71 changes: 71 additions & 0 deletions entc/gen/template/dialect/sql/feature/execquery.tmpl
@@ -0,0 +1,71 @@
{{/*
Copyright 2019-present Facebook Inc. All rights reserved.
This source code is licensed under the Apache 2.0 license found
in the LICENSE file in the root directory of this source tree.
*/}}

{{/* gotype: entgo.io/ent/entc/gen.Graph*/}}

{{ define "import/additional/stdsql" }}
{{- if $.FeatureEnabled "sql/execquery" }}
stdsql "database/sql"
{{- end }}
{{ end }}

{{/* Template for adding "ExecContext"/"QueryContext" methods to the config. */}}
{{ define "config/additional/sql/execquery" }}
{{- if $.FeatureEnabled "sql/execquery" }}
// ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it.
// See, database/sql#DB.ExecContext for more information.
func (c *config) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) {
ex, ok := c.driver.(interface {
ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Driver.ExecContext is not supported")
}
return ex.ExecContext(ctx, query, args...)
}

// QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it.
// See, database/sql#DB.QueryContext for more information.
func (c *config) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) {
q, ok := c.driver.(interface {
QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Driver.QueryContext is not supported")
}
return q.QueryContext(ctx, query, args...)
}
{{- end }}
{{ end }}

{{/* Template for adding "ExecContext"/"QueryContext" methods to the client. */}}
{{ define "tx/additional/sql/execquery" }}
{{- if $.FeatureEnabled "sql/execquery" }}
// ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it.
// See, database/sql#Tx.ExecContext for more information.
func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) {
ex, ok := tx.tx.(interface {
ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error)
})
if !ok {
return nil, fmt.Errorf("Tx.ExecContext is not supported")
}
return ex.ExecContext(ctx, query, args...)
}

// QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it.
// See, database/sql#Tx.QueryContext for more information.
func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) {
q, ok := tx.tx.(interface {
QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error)
})
if !ok {
return nil, fmt.Errorf("Tx.QueryContext is not supported")
}
return q.QueryContext(ctx, query, args...)
}
{{- end }}
{{ end }}
6 changes: 6 additions & 0 deletions entc/gen/template/tx.tmpl
Expand Up @@ -174,4 +174,10 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{}

var _ dialect.Driver = (*txDriver)(nil)

{{- with $tmpls := matchTemplate "tx/additional/*" "tx/additional/*/*" }}
{{- range $tmpl := $tmpls }}
{{- xtemplate $tmpl $ }}
{{- end }}
{{- end }}

{{ end }}
28 changes: 28 additions & 0 deletions entc/integration/ent/config.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion entc/integration/ent/generate.go
Expand Up @@ -4,4 +4,4 @@

package ent

//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature entql,sql/modifier,sql/lock,sql/upsert --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature entql,sql/modifier,sql/lock,sql/upsert,sql/execquery --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema
26 changes: 26 additions & 0 deletions entc/integration/ent/tx.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions entc/integration/integration_test.go
Expand Up @@ -38,6 +38,7 @@ import (
"entgo.io/ent/entc/integration/ent/pet"
"entgo.io/ent/entc/integration/ent/schema"
"entgo.io/ent/entc/integration/ent/user"
"entgo.io/ent/entc/integration/privacy/ent/task"

"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
Expand Down Expand Up @@ -128,6 +129,7 @@ var (
Delete,
Upsert,
Relation,
ExecQuery,
Predicate,
AddValues,
ClearEdges,
Expand Down Expand Up @@ -680,6 +682,25 @@ func Select(t *testing.T, client *ent.Client) {
require.Equal(lab.QueryUsers().CountX(ctx), gs[1].UsersCount)
}

func ExecQuery(t *testing.T, client *ent.Client) {
require := require.New(t)
ctx := context.Background()
rows, err := client.QueryContext(ctx, "SELECT 1")
require.NoError(err)
require.True(rows.Next())
require.NoError(rows.Close())
tx, err := client.Tx(ctx)
require.NoError(err)
tx.Task.Create().ExecX(ctx)
require.Equal(1, tx.Task.Query().CountX(ctx))
rows, err = tx.QueryContext(ctx, "SELECT COUNT(*) FROM "+task.Table)
require.NoError(err)
count, err := sql.ScanInt(rows)
require.NoError(err)
require.Equal(1, count)
require.NoError(tx.Commit())
}

func Predicate(t *testing.T, client *ent.Client) {
require := require.New(t)
ctx := context.Background()
Expand Down

0 comments on commit 8498d64

Please sign in to comment.