diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 35d8dfda9c..a448ab0649 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,10 +14,13 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - uses: actions/setup-go@v2 + with: + go-version: 1.17 - name: Run linters - uses: golangci/golangci-lint-action@v2.5.2 + uses: golangci/golangci-lint-action@v3.1.0 with: - version: v1.44.0 + version: v1.45.2 unit: runs-on: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml index 709f8b1ad8..8335b6ef5e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,4 +1,5 @@ run: + go: '1.17' timeout: 3m linters-settings: diff --git a/dialect/dialect.go b/dialect/dialect.go index 293860c284..95f07f95e1 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -93,12 +93,36 @@ func (d *DebugDriver) Exec(ctx context.Context, query string, args, v interface{ return d.Driver.Exec(ctx, query, args, v) } +// ExecContext logs its params and calls the underlying driver ExecContext method if it is supported. +func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + drv, ok := d.Driver.(interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.ExecContext is not supported") + } + d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args)) + return drv.ExecContext(ctx, query, args...) +} + // Query logs its params and calls the underlying driver Query method. func (d *DebugDriver) Query(ctx context.Context, query string, args, v interface{}) error { d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v", query, args)) return d.Driver.Query(ctx, query, args, v) } +// QueryContext logs its params and calls the underlying driver QueryContext method if it is supported. +func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + drv, ok := d.Driver.(interface { + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.QueryContext is not supported") + } + d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args)) + return drv.QueryContext(ctx, query, args...) +} + // Tx adds an log-id for the transaction and calls the underlying driver Tx command. func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) { tx, err := d.Driver.Tx(ctx) @@ -110,7 +134,7 @@ func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) { return &DebugTx{tx, id, d.log, ctx}, nil } -// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it's supported. +// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it is supported. func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { drv, ok := d.Driver.(interface { BeginTx(context.Context, *sql.TxOptions) (Tx, error) @@ -141,12 +165,36 @@ func (d *DebugTx) Exec(ctx context.Context, query string, args, v interface{}) e return d.Tx.Exec(ctx, query, args, v) } +// ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported. +func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + drv, ok := d.Tx.(interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.ExecContext is not supported") + } + d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args)) + return drv.ExecContext(ctx, query, args...) +} + // Query logs its params and calls the underlying transaction Query method. func (d *DebugTx) Query(ctx context.Context, query string, args, v interface{}) error { d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v", d.id, query, args)) return d.Tx.Query(ctx, query, args, v) } +// QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported. +func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + drv, ok := d.Tx.(interface { + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.QueryContext is not supported") + } + d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args)) + return drv.QueryContext(ctx, query, args...) +} + // Commit logs this step and calls the underlying transaction Commit method. func (d *DebugTx) Commit() error { d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id)) diff --git a/dialect/sql/driver.go b/dialect/sql/driver.go index 6f79be8dd3..2d24957d27 100644 --- a/dialect/sql/driver.go +++ b/dialect/sql/driver.go @@ -62,8 +62,8 @@ func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, erro return nil, err } return &Tx{ - ExecQuerier: Conn{tx}, - Tx: tx, + Conn: Conn{tx}, + Tx: tx, }, nil } @@ -72,7 +72,7 @@ func (d *Driver) Close() error { return d.DB().Close() } // Tx implements dialect.Tx interface. type Tx struct { - dialect.ExecQuerier + Conn driver.Tx } diff --git a/entc/gen/feature.go b/entc/gen/feature.go index 357b93e16f..55c55ae38c 100644 --- a/entc/gen/feature.go +++ b/entc/gen/feature.go @@ -84,6 +84,14 @@ var ( Description: "Allows users to attach custom modifiers to queries", } + // FeatureExecQuery provides a feature-flag for exposing the ExecContext/QueryContext methods of the underlying SQL drivers. + FeatureExecQuery = Feature{ + Name: "sql/execquery", + Stage: Experimental, + Default: false, + Description: "Allows users to execute statements using the ExecContext/QueryContext methods of the underlying driver", + } + // FeatureUpsert provides a feature-flag for adding upsert (ON CONFLICT) capabilities to create builders. FeatureUpsert = Feature{ Name: "sql/upsert", @@ -107,6 +115,7 @@ var ( FeatureSchemaConfig, FeatureLock, FeatureModifier, + FeatureExecQuery, FeatureUpsert, FeatureVersionedMigration, } diff --git a/entc/gen/template.go b/entc/gen/template.go index 2f8f14cafd..563812b5bf 100644 --- a/entc/gen/template.go +++ b/entc/gen/template.go @@ -195,6 +195,8 @@ var ( "model/additional/*", "model/comment/additional/*", "model/edges/fields/additional/*", + "tx/additional/*", + "tx/additional/*/*", "update/additional/*", "query/additional/*", } diff --git a/entc/gen/template/config.tmpl b/entc/gen/template/config.tmpl index 9831d3d75c..e166bfc547 100644 --- a/entc/gen/template/config.tmpl +++ b/entc/gen/template/config.tmpl @@ -108,4 +108,10 @@ func Driver(driver dialect.Driver) Option { {{- end }} {{- end }} +{{- with $tmpls := matchTemplate "config/additional/*" "config/additional/*/*" }} + {{- range $tmpl := $tmpls }} + {{- xtemplate $tmpl $ }} + {{- end }} +{{- end }} + {{ end }} diff --git a/entc/gen/template/dialect/sql/feature/execquery.tmpl b/entc/gen/template/dialect/sql/feature/execquery.tmpl new file mode 100644 index 0000000000..50ad096a6f --- /dev/null +++ b/entc/gen/template/dialect/sql/feature/execquery.tmpl @@ -0,0 +1,71 @@ +{{/* +Copyright 2019-present Facebook Inc. All rights reserved. +This source code is licensed under the Apache 2.0 license found +in the LICENSE file in the root directory of this source tree. +*/}} + +{{/* gotype: entgo.io/ent/entc/gen.Graph*/}} + +{{- define "import/additional/stdsql" -}} + {{- if $.FeatureEnabled "sql/execquery" }} + stdsql "database/sql" + {{- end }} +{{- end -}} + +{{/* Template for adding "ExecContext"/"QueryContext" methods to the config. */}} +{{ define "config/additional/sql/execquery" }} + {{- if $.FeatureEnabled "sql/execquery" }} + // ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it. + // See, database/sql#DB.ExecContext for more information. + func (c *config) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) { + ex, ok := c.driver.(interface { + ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.ExecContext is not supported") + } + return ex.ExecContext(ctx, query, args...) + } + + // QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it. + // See, database/sql#DB.QueryContext for more information. + func (c *config) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) { + q, ok := c.driver.(interface { + QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.QueryContext is not supported") + } + return q.QueryContext(ctx, query, args...) + } + {{- end }} +{{ end }} + +{{/* Template for adding "ExecContext"/"QueryContext" methods to the client. */}} +{{ define "tx/additional/sql/execquery" }} + {{- if $.FeatureEnabled "sql/execquery" }} + // ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it. + // See, database/sql#Tx.ExecContext for more information. + func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) { + ex, ok := tx.tx.(interface { + ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.ExecContext is not supported") + } + return ex.ExecContext(ctx, query, args...) + } + + // QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it. + // See, database/sql#Tx.QueryContext for more information. + func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) { + q, ok := tx.tx.(interface { + QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.QueryContext is not supported") + } + return q.QueryContext(ctx, query, args...) + } + {{- end }} +{{ end }} \ No newline at end of file diff --git a/entc/gen/template/tx.tmpl b/entc/gen/template/tx.tmpl index c2b8e17c98..ecc303a25b 100644 --- a/entc/gen/template/tx.tmpl +++ b/entc/gen/template/tx.tmpl @@ -174,4 +174,10 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{} var _ dialect.Driver = (*txDriver)(nil) +{{- with $tmpls := matchTemplate "tx/additional/*" "tx/additional/*/*" }} + {{- range $tmpl := $tmpls }} + {{- xtemplate $tmpl $ }} + {{- end }} +{{- end }} + {{ end }} diff --git a/entc/integration/ent/config.go b/entc/integration/ent/config.go index b47ceb0002..cd1c9e3a67 100644 --- a/entc/integration/ent/config.go +++ b/entc/integration/ent/config.go @@ -7,6 +7,10 @@ package ent import ( + "context" + stdsql "database/sql" + "fmt" + "entgo.io/ent" "entgo.io/ent/dialect" ) @@ -74,3 +78,27 @@ func Driver(driver dialect.Driver) Option { c.driver = driver } } + +// ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it. +// See, database/sql#DB.ExecContext for more information. +func (c *config) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) { + ex, ok := c.driver.(interface { + ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.ExecContext is not supported") + } + return ex.ExecContext(ctx, query, args...) +} + +// QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it. +// See, database/sql#DB.QueryContext for more information. +func (c *config) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) { + q, ok := c.driver.(interface { + QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Driver.QueryContext is not supported") + } + return q.QueryContext(ctx, query, args...) +} diff --git a/entc/integration/ent/generate.go b/entc/integration/ent/generate.go index fb51bc9776..6dcdac3eb8 100644 --- a/entc/integration/ent/generate.go +++ b/entc/integration/ent/generate.go @@ -4,4 +4,4 @@ package ent -//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature entql,sql/modifier,sql/lock,sql/upsert --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema +//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature entql,sql/modifier,sql/lock,sql/upsert,sql/execquery --template ./template --header "// Copyright 2019-present Facebook Inc. All rights reserved.\n// This source code is licensed under the Apache 2.0 license found\n// in the LICENSE file in the root directory of this source tree.\n\n// Code generated by entc, DO NOT EDIT." ./schema diff --git a/entc/integration/ent/tx.go b/entc/integration/ent/tx.go index 4ad31a9055..7e2ab871d5 100644 --- a/entc/integration/ent/tx.go +++ b/entc/integration/ent/tx.go @@ -8,6 +8,8 @@ package ent import ( "context" + stdsql "database/sql" + "fmt" "sync" "entgo.io/ent/dialect" @@ -251,3 +253,27 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{} } var _ dialect.Driver = (*txDriver)(nil) + +// ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it. +// See, database/sql#Tx.ExecContext for more information. +func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...interface{}) (stdsql.Result, error) { + ex, ok := tx.tx.(interface { + ExecContext(context.Context, string, ...interface{}) (stdsql.Result, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.ExecContext is not supported") + } + return ex.ExecContext(ctx, query, args...) +} + +// QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it. +// See, database/sql#Tx.QueryContext for more information. +func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdsql.Rows, error) { + q, ok := tx.tx.(interface { + QueryContext(context.Context, string, ...interface{}) (*stdsql.Rows, error) + }) + if !ok { + return nil, fmt.Errorf("Tx.QueryContext is not supported") + } + return q.QueryContext(ctx, query, args...) +} diff --git a/entc/integration/integration_test.go b/entc/integration/integration_test.go index 8c64f5f0b3..d080099c21 100644 --- a/entc/integration/integration_test.go +++ b/entc/integration/integration_test.go @@ -38,6 +38,7 @@ import ( "entgo.io/ent/entc/integration/ent/pet" "entgo.io/ent/entc/integration/ent/schema" "entgo.io/ent/entc/integration/ent/user" + "entgo.io/ent/entc/integration/privacy/ent/task" "github.com/go-sql-driver/mysql" "github.com/lib/pq" @@ -128,6 +129,7 @@ var ( Delete, Upsert, Relation, + ExecQuery, Predicate, AddValues, ClearEdges, @@ -680,6 +682,25 @@ func Select(t *testing.T, client *ent.Client) { require.Equal(lab.QueryUsers().CountX(ctx), gs[1].UsersCount) } +func ExecQuery(t *testing.T, client *ent.Client) { + require := require.New(t) + ctx := context.Background() + rows, err := client.QueryContext(ctx, "SELECT 1") + require.NoError(err) + require.True(rows.Next()) + require.NoError(rows.Close()) + tx, err := client.Tx(ctx) + require.NoError(err) + tx.Task.Create().ExecX(ctx) + require.Equal(1, tx.Task.Query().CountX(ctx)) + rows, err = tx.QueryContext(ctx, "SELECT COUNT(*) FROM "+task.Table) + require.NoError(err) + count, err := sql.ScanInt(rows) + require.NoError(err) + require.Equal(1, count) + require.NoError(tx.Commit()) +} + func Predicate(t *testing.T, client *ent.Client) { require := require.New(t) ctx := context.Background()