From d14446912ffe770fcd2a9c8689f5d8f9cfca375a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Nieto?= Date: Thu, 16 Jun 2022 18:56:08 -0500 Subject: [PATCH] fix panic on json.RawMessage parse --- adapter/postgresql/database_pgx.go | 1 + adapter/postgresql/helper_test.go | 1 + adapter/postgresql/postgresql_test.go | 6 + internal/sqladapter/collection.go | 68 ++++----- internal/sqladapter/result.go | 15 +- internal/sqladapter/session.go | 198 ++++++++++++++----------- internal/sqladapter/sqladapter_test.go | 10 +- internal/sqlbuilder/builder.go | 20 ++- internal/sqlbuilder/convert.go | 10 +- internal/testsuite/sql_suite.go | 89 ++++++++++- session.go | 7 +- settings.go | 23 +++ 12 files changed, 296 insertions(+), 152 deletions(-) diff --git a/adapter/postgresql/database_pgx.go b/adapter/postgresql/database_pgx.go index 94efa00b..954a9382 100644 --- a/adapter/postgresql/database_pgx.go +++ b/adapter/postgresql/database_pgx.go @@ -1,3 +1,4 @@ +//go:build !pq // +build !pq package postgresql diff --git a/adapter/postgresql/helper_test.go b/adapter/postgresql/helper_test.go index 04654e71..e553df17 100644 --- a/adapter/postgresql/helper_test.go +++ b/adapter/postgresql/helper_test.go @@ -194,6 +194,7 @@ func (h *Helper) TearUp() error { , integer_array integer[] , string_array text[] , jsonb_map jsonb + , raw_jsonb_map jsonb , integer_array_ptr integer[] , string_array_ptr text[] diff --git a/adapter/postgresql/postgresql_test.go b/adapter/postgresql/postgresql_test.go index 7d3708e5..a2f54f80 100644 --- a/adapter/postgresql/postgresql_test.go +++ b/adapter/postgresql/postgresql_test.go @@ -25,6 +25,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "fmt" "math/rand" "strings" @@ -229,6 +230,8 @@ func testPostgreSQLTypes(t *testing.T, sess db.Session) { StringArray StringArray `db:"string_array,stringarray"` JSONBMap JSONBMap `db:"jsonb_map"` + RawJSONBMap json.RawMessage `db:"raw_jsonb_map"` + PGTypeInline `db:",inline"` PGTypeAutoInline `db:",inline"` @@ -329,6 +332,9 @@ func testPostgreSQLTypes(t *testing.T, sess db.Session) { AutoJSONBMapInteger: map[string]interface{}{"a": 12.0, "b": 13.0}, }, }, + PGType{ + RawJSONBMap: json.RawMessage(`{"foo": "bar"}`), + }, PGType{ IntegerValue: integerValue, StringValue: stringValue, diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go index d77cb4ae..f70d0c93 100644 --- a/internal/sqladapter/collection.go +++ b/internal/sqladapter/collection.go @@ -67,42 +67,43 @@ type condsFilter interface { // collection is the implementation of Collection. type collection struct { - name string - sess Session - + name string adapter CollectionAdapter } -// NewCollection initializes a Collection by wrapping a CollectionAdapter. -func NewCollection(sess Session, name string, adapter CollectionAdapter) Collection { +type collectionWithSession struct { + *collection + + session Session +} + +func newCollection(name string, adapter CollectionAdapter) *collection { if adapter == nil { - panic("upper: received nil adapter") + panic("upper: nil adapter") } - c := &collection{ - sess: sess, + return &collection{ name: name, adapter: adapter, } - return c } -func (c *collection) SQL() db.SQL { - return c.sess.SQL() +func (c *collectionWithSession) SQL() db.SQL { + return c.session.SQL() } -func (c *collection) Session() db.Session { - return c.sess +func (c *collectionWithSession) Session() db.Session { + return c.session } -func (c *collection) Name() string { +func (c *collectionWithSession) Name() string { return c.name } -func (c *collection) Count() (uint64, error) { +func (c *collectionWithSession) Count() (uint64, error) { return c.Find().Count() } -func (c *collection) Insert(item interface{}) (db.InsertResult, error) { +func (c *collectionWithSession) Insert(item interface{}) (db.InsertResult, error) { id, err := c.adapter.Insert(c, item) if err != nil { return nil, err @@ -111,11 +112,11 @@ func (c *collection) Insert(item interface{}) (db.InsertResult, error) { return db.NewInsertResult(id), nil } -func (c *collection) PrimaryKeys() ([]string, error) { - return c.sess.PrimaryKeys(c.Name()) +func (c *collectionWithSession) PrimaryKeys() ([]string, error) { + return c.session.PrimaryKeys(c.Name()) } -func (c *collection) filterConds(conds ...interface{}) ([]interface{}, error) { +func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) { pk, err := c.PrimaryKeys() if err != nil { return nil, err @@ -131,15 +132,16 @@ func (c *collection) filterConds(conds ...interface{}) ([]interface{}, error) { return conds, nil } -func (c *collection) Find(conds ...interface{}) db.Result { +func (c *collectionWithSession) Find(conds ...interface{}) db.Result { filteredConds, err := c.filterConds(conds...) if err != nil { res := &Result{} res.setErr(err) return res } + res := NewResult( - c.sess.SQL(), + c.session.SQL(), c.Name(), filteredConds, ) @@ -149,14 +151,14 @@ func (c *collection) Find(conds ...interface{}) db.Result { return res } -func (c *collection) Exists() (bool, error) { - if err := c.sess.TableExists(c.Name()); err != nil { +func (c *collectionWithSession) Exists() (bool, error) { + if err := c.session.TableExists(c.Name()); err != nil { return false, err } return true, nil } -func (c *collection) InsertReturning(item interface{}) error { +func (c *collectionWithSession) InsertReturning(item interface{}) error { if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr { return fmt.Errorf("Expecting a pointer but got %T", item) } @@ -175,12 +177,12 @@ func (c *collection) InsertReturning(item interface{}) error { } var tx Session - isTransaction := c.sess.IsTransaction() + isTransaction := c.session.IsTransaction() if isTransaction { - tx = c.sess + tx = c.session } else { var err error - tx, err = c.sess.NewTransaction(c.sess.Context(), nil) + tx, err = c.session.NewTransaction(c.session.Context(), nil) if err != nil { return err } @@ -261,7 +263,7 @@ cancel: return err } -func (c *collection) UpdateReturning(item interface{}) error { +func (c *collectionWithSession) UpdateReturning(item interface{}) error { if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr { return fmt.Errorf("Expecting a pointer but got %T", item) } @@ -280,14 +282,14 @@ func (c *collection) UpdateReturning(item interface{}) error { } var tx Session - isTransaction := c.sess.IsTransaction() + isTransaction := c.session.IsTransaction() if isTransaction { - tx = c.sess + tx = c.session } else { // Not within a transaction, let's create one. var err error - tx, err = c.sess.NewTransaction(c.sess.Context(), nil) + tx, err = c.session.NewTransaction(c.session.Context(), nil) if err != nil { return err } @@ -355,12 +357,12 @@ cancel: return err } -func (c *collection) Truncate() error { +func (c *collectionWithSession) Truncate() error { stmt := exql.Statement{ Type: exql.Truncate, Table: exql.TableWithName(c.Name()), } - if _, err := c.sess.SQL().Exec(&stmt); err != nil { + if _, err := c.session.SQL().Exec(&stmt); err != nil { return err } return nil diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go index 1b25e2ea..3a7e9392 100644 --- a/internal/sqladapter/result.go +++ b/internal/sqladapter/result.go @@ -213,7 +213,7 @@ func (r *Result) Select(fields ...interface{}) db.Result { // String satisfies fmt.Stringer func (r *Result) String() string { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { panic(err.Error()) } @@ -222,7 +222,7 @@ func (r *Result) String() string { // All dumps all Results into a pointer to an slice of structs or maps. func (r *Result) All(dst interface{}) error { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return err @@ -235,11 +235,12 @@ func (r *Result) All(dst interface{}) error { // One fetches only one Result from the set. func (r *Result) One(dst interface{}) error { one := r.Limit(1).(*Result) - query, err := one.buildPaginator() + query, err := one.Paginator() if err != nil { r.setErr(err) return err } + err = query.Iterator().One(dst) r.setErr(err) return err @@ -251,7 +252,7 @@ func (r *Result) Next(dst interface{}) bool { defer r.iterMu.Unlock() if r.iter == nil { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return false @@ -309,7 +310,7 @@ func (r *Result) Update(values interface{}) error { } func (r *Result) TotalPages() (uint, error) { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return 0, err @@ -325,7 +326,7 @@ func (r *Result) TotalPages() (uint, error) { } func (r *Result) TotalEntries() (uint64, error) { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return 0, err @@ -391,7 +392,7 @@ func (r *Result) Count() (uint64, error) { return counter.Count, nil } -func (r *Result) buildPaginator() (db.Paginator, error) { +func (r *Result) Paginator() (db.Paginator, error) { if err := r.Err(); err != nil { return nil, err } diff --git a/internal/sqladapter/session.go b/internal/sqladapter/session.go index a2fd95a1..b9a6921d 100644 --- a/internal/sqladapter/session.go +++ b/internal/sqladapter/session.go @@ -161,7 +161,7 @@ type Session interface { // Context returns the default context the session is using. Context() context.Context - // SetContext sets a default context for the session. + // SetContext sets the default context for the session. SetContext(context.Context) NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) @@ -183,31 +183,35 @@ type Session interface { // NewTx wraps a *sql.Tx and returns a Tx. func NewTx(adapter AdapterSession, tx *sql.Tx) (Session, error) { - sess := &session{ - Settings: db.DefaultSettings, - - sqlTx: tx, - adapter: adapter, - cachedPKs: cache.NewCache(), - cachedCollections: cache.NewCache(), - cachedStatements: cache.NewCache(), - } - sess.builder = sqlbuilder.WithSession(sess, adapter.Template()) - return sess, nil + sessTx := &sessionWithContext{ + session: &session{ + Settings: db.DefaultSettings, + + sqlTx: tx, + adapter: adapter, + cachedPKs: cache.NewCache(), + cachedCollections: cache.NewCache(), + cachedStatements: cache.NewCache(), + }, + ctx: context.Background(), + } + return sessTx, nil } // NewSession creates a new Session. func NewSession(connURL db.ConnectionURL, adapter AdapterSession) Session { - sess := &session{ - Settings: db.DefaultSettings, - - connURL: connURL, - adapter: adapter, - cachedPKs: cache.NewCache(), - cachedCollections: cache.NewCache(), - cachedStatements: cache.NewCache(), + sess := &sessionWithContext{ + session: &session{ + Settings: db.DefaultSettings, + + connURL: connURL, + adapter: adapter, + cachedPKs: cache.NewCache(), + cachedCollections: cache.NewCache(), + cachedStatements: cache.NewCache(), + }, + ctx: context.Background(), } - sess.builder = sqlbuilder.WithSession(sess, adapter.Template()) return sess } @@ -224,7 +228,6 @@ type session struct { name string mu sync.Mutex // guards ctx, txOptions - ctx context.Context txOptions *sql.TxOptions sqlDBMu sync.Mutex // guards sess, baseTx @@ -243,36 +246,46 @@ type session struct { template *exql.Template } -var ( - _ = db.Session(&session{}) -) +type sessionWithContext struct { + *session -func (sess *session) WithContext(ctx context.Context) db.Session { - newDB, _ := sess.NewClone(sess.adapter, false) - newDB.SetContext(ctx) - return newDB + ctx context.Context } -func (sess *session) Tx(fn func(sess db.Session) error) error { +func (sess *sessionWithContext) WithContext(ctx context.Context) db.Session { + if ctx == nil { + panic("nil context") + } + newSess := &sessionWithContext{ + session: sess.session, + ctx: ctx, + } + return newSess +} + +func (sess *sessionWithContext) Tx(fn func(sess db.Session) error) error { return TxContext(sess.Context(), sess, fn, nil) } -func (sess *session) TxContext(ctx context.Context, fn func(sess db.Session) error, opts *sql.TxOptions) error { +func (sess *sessionWithContext) TxContext(ctx context.Context, fn func(sess db.Session) error, opts *sql.TxOptions) error { return TxContext(ctx, sess, fn, opts) } -func (sess *session) SQL() db.SQL { - return sess.builder +func (sess *sessionWithContext) SQL() db.SQL { + return sqlbuilder.WithSession( + sess, + sess.adapter.Template(), + ) } -func (sess *session) Err(errIn error) (errOur error) { +func (sess *sessionWithContext) Err(errIn error) (errOur error) { if convertError, ok := sess.adapter.(errorConverter); ok { return convertError.Err(errIn) } return errIn } -func (sess *session) PrimaryKeys(tableName string) ([]string, error) { +func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) { h := cache.String(tableName) cachedPK, ok := sess.cachedPKs.ReadRaw(h) if ok { @@ -288,11 +301,11 @@ func (sess *session) PrimaryKeys(tableName string) ([]string, error) { return pk, nil } -func (sess *session) TableExists(name string) error { +func (sess *sessionWithContext) TableExists(name string) error { return sess.adapter.TableExists(sess, name) } -func (sess *session) NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) { +func (sess *sessionWithContext) NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) { if ctx == nil { ctx = context.Background() } @@ -316,7 +329,7 @@ func (sess *session) NewTransaction(ctx context.Context, opts *sql.TxOptions) (S return clone, nil } -func (sess *session) Collections() ([]db.Collection, error) { +func (sess *sessionWithContext) Collections() ([]db.Collection, error) { names, err := sess.adapter.Collections(sess) if err != nil { return nil, err @@ -330,11 +343,11 @@ func (sess *session) Collections() ([]db.Collection, error) { return collections, nil } -func (sess *session) ConnectionURL() db.ConnectionURL { +func (sess *sessionWithContext) ConnectionURL() db.ConnectionURL { return sess.connURL } -func (sess *session) Open() error { +func (sess *sessionWithContext) Open() error { var sqlDB *sql.DB var err error @@ -345,6 +358,7 @@ func (sess *session) Open() error { } sqlDB.SetConnMaxLifetime(sess.ConnMaxLifetime()) + sqlDB.SetConnMaxIdleTime(sess.ConnMaxIdleTime()) sqlDB.SetMaxIdleConns(sess.MaxIdleConns()) sqlDB.SetMaxOpenConns(sess.MaxOpenConns()) return nil @@ -357,7 +371,7 @@ func (sess *session) Open() error { return sess.BindDB(sqlDB) } -func (sess *session) Get(record db.Record, id interface{}) error { +func (sess *sessionWithContext) Get(record db.Record, id interface{}) error { store := record.Store(sess) if getter, ok := store.(db.StoreGetter); ok { return getter.Get(record, id) @@ -365,7 +379,7 @@ func (sess *session) Get(record db.Record, id interface{}) error { return store.Find(id).One(record) } -func (sess *session) Save(record db.Record) error { +func (sess *sessionWithContext) Save(record db.Record) error { if record == nil { return db.ErrNilRecord } @@ -401,7 +415,7 @@ func (sess *session) Save(record db.Record) error { return recordCreate(store, record) } -func (sess *session) Delete(record db.Record) error { +func (sess *sessionWithContext) Delete(record db.Record) error { if record == nil { return db.ErrNilRecord } @@ -441,32 +455,27 @@ func (sess *session) Delete(record db.Record) error { return nil } -func (sess *session) DB() *sql.DB { +func (sess *sessionWithContext) DB() *sql.DB { return sess.sqlDB } -func (sess *session) SetContext(ctx context.Context) { +func (sess *sessionWithContext) SetContext(ctx context.Context) { sess.mu.Lock() sess.ctx = ctx sess.mu.Unlock() } -func (sess *session) Context() context.Context { - sess.mu.Lock() - defer sess.mu.Unlock() - if sess.ctx == nil { - return context.Background() - } +func (sess *sessionWithContext) Context() context.Context { return sess.ctx } -func (sess *session) SetTxOptions(txOptions sql.TxOptions) { +func (sess *sessionWithContext) SetTxOptions(txOptions sql.TxOptions) { sess.mu.Lock() sess.txOptions = &txOptions sess.mu.Unlock() } -func (sess *session) TxOptions() *sql.TxOptions { +func (sess *sessionWithContext) TxOptions() *sql.TxOptions { sess.mu.Lock() defer sess.mu.Unlock() if sess.txOptions == nil { @@ -475,7 +484,7 @@ func (sess *session) TxOptions() *sql.TxOptions { return sess.txOptions } -func (sess *session) BindTx(ctx context.Context, tx *sql.Tx) error { +func (sess *sessionWithContext) BindTx(ctx context.Context, tx *sql.Tx) error { sess.sqlDBMu.Lock() defer sess.sqlDBMu.Unlock() @@ -487,29 +496,29 @@ func (sess *session) BindTx(ctx context.Context, tx *sql.Tx) error { return nil } -func (sess *session) Commit() error { +func (sess *sessionWithContext) Commit() error { if sess.sqlTx != nil { return sess.sqlTx.Commit() } return db.ErrNotWithinTransaction } -func (sess *session) Rollback() error { +func (sess *sessionWithContext) Rollback() error { if sess.sqlTx != nil { return sess.sqlTx.Rollback() } return db.ErrNotWithinTransaction } -func (sess *session) IsTransaction() bool { +func (sess *sessionWithContext) IsTransaction() bool { return sess.sqlTx != nil } -func (sess *session) Transaction() *sql.Tx { +func (sess *sessionWithContext) Transaction() *sql.Tx { return sess.sqlTx } -func (sess *session) Name() string { +func (sess *sessionWithContext) Name() string { sess.lookupNameOnce.Do(func() { if sess.name == "" { sess.name, _ = sess.adapter.LookupName(sess) @@ -519,7 +528,8 @@ func (sess *session) Name() string { return sess.name } -func (sess *session) BindDB(sqlDB *sql.DB) error { +func (sess *sessionWithContext) BindDB(sqlDB *sql.DB) error { + sess.sqlDBMu.Lock() sess.sqlDB = sqlDB sess.sqlDBMu.Unlock() @@ -538,28 +548,35 @@ func (sess *session) BindDB(sqlDB *sql.DB) error { return nil } -func (sess *session) Ping() error { +func (sess *sessionWithContext) Ping() error { if sess.sqlDB != nil { return sess.sqlDB.Ping() } return db.ErrNotConnected } -func (sess *session) SetConnMaxLifetime(t time.Duration) { +func (sess *sessionWithContext) SetConnMaxLifetime(t time.Duration) { sess.Settings.SetConnMaxLifetime(t) if sessDB := sess.DB(); sessDB != nil { sessDB.SetConnMaxLifetime(sess.Settings.ConnMaxLifetime()) } } -func (sess *session) SetMaxIdleConns(n int) { +func (sess *sessionWithContext) SetConnMaxIdleTime(t time.Duration) { + sess.Settings.SetConnMaxIdleTime(t) + if sessDB := sess.DB(); sessDB != nil { + sessDB.SetConnMaxIdleTime(sess.Settings.ConnMaxIdleTime()) + } +} + +func (sess *sessionWithContext) SetMaxIdleConns(n int) { sess.Settings.SetMaxIdleConns(n) if sessDB := sess.DB(); sessDB != nil { sessDB.SetMaxIdleConns(sess.Settings.MaxIdleConns()) } } -func (sess *session) SetMaxOpenConns(n int) { +func (sess *sessionWithContext) SetMaxOpenConns(n int) { sess.Settings.SetMaxOpenConns(n) if sessDB := sess.DB(); sessDB != nil { sessDB.SetMaxOpenConns(sess.Settings.MaxOpenConns()) @@ -567,7 +584,7 @@ func (sess *session) SetMaxOpenConns(n int) { } // Reset removes all caches. -func (sess *session) Reset() { +func (sess *sessionWithContext) Reset() { sess.cacheMu.Lock() defer sess.cacheMu.Unlock() @@ -580,8 +597,9 @@ func (sess *session) Reset() { } } -func (sess *session) NewClone(adapter AdapterSession, checkConn bool) (Session, error) { - newSess := NewSession(sess.connURL, adapter).(*session) +func (sess *sessionWithContext) NewClone(adapter AdapterSession, checkConn bool) (Session, error) { + + newSess := NewSession(sess.connURL, adapter).(*sessionWithContext) newSess.name = sess.name newSess.sqlDB = sess.sqlDB @@ -602,7 +620,7 @@ func (sess *session) NewClone(adapter AdapterSession, checkConn bool) (Session, return newSess, nil } -func (sess *session) Close() error { +func (sess *sessionWithContext) Close() error { defer func() { sess.sqlDBMu.Lock() sess.sqlDB = nil @@ -630,21 +648,21 @@ func (sess *session) Close() error { return nil } -func (sess *session) Collection(name string) db.Collection { +func (sess *sessionWithContext) Collection(name string) db.Collection { sess.cacheMu.Lock() defer sess.cacheMu.Unlock() h := cache.String(name) - - cachedCol, ok := sess.cachedCollections.ReadRaw(h) - if ok { - return cachedCol.(db.Collection) + col, ok := sess.cachedCollections.ReadRaw(h) + if !ok { + col = newCollection(name, sess.adapter.NewCollection()) + sess.cachedCollections.Write(h, col) } - col := NewCollection(sess, name, sess.adapter.NewCollection()) - sess.cachedCollections.Write(h, col) - - return col + return &collectionWithSession{ + collection: col.(*collection), + session: sess, + } } func queryLog(status *db.QueryStatus) { @@ -664,7 +682,7 @@ func queryLog(status *db.QueryStatus) { db.LC().Debug(status) } -func (sess *session) StatementPrepare(ctx context.Context, stmt *exql.Statement) (sqlStmt *sql.Stmt, err error) { +func (sess *sessionWithContext) StatementPrepare(ctx context.Context, stmt *exql.Statement) (sqlStmt *sql.Stmt, err error) { var query string defer func(start time.Time) { @@ -694,7 +712,7 @@ func (sess *session) StatementPrepare(ctx context.Context, stmt *exql.Statement) return } -func (sess *session) ConvertValue(value interface{}) interface{} { +func (sess *sessionWithContext) ConvertValue(value interface{}) interface{} { if scannerValuer, ok := value.(sqlbuilder.ScannerValuer); ok { return scannerValuer } @@ -717,7 +735,7 @@ func (sess *session) ConvertValue(value interface{}) interface{} { return value } -func (sess *session) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { +func (sess *sessionWithContext) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { var query string defer func(start time.Time) { @@ -781,7 +799,7 @@ func (sess *session) StatementExec(ctx context.Context, stmt *exql.Statement, ar } // StatementQuery compiles and executes a statement that returns rows. -func (sess *session) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { +func (sess *sessionWithContext) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { var query string defer func(start time.Time) { @@ -822,12 +840,11 @@ func (sess *session) StatementQuery(ctx context.Context, stmt *exql.Statement, a rows, err = compat.QueryContext(sess.sqlDB, ctx, query, args) return - } // StatementQueryRow compiles and executes a statement that returns at most one // row. -func (sess *session) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { +func (sess *sessionWithContext) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { var query string defer func(start time.Time) { @@ -871,7 +888,7 @@ func (sess *session) StatementQueryRow(ctx context.Context, stmt *exql.Statement } // Driver returns the underlying *sql.DB or *sql.Tx instance. -func (sess *session) Driver() interface{} { +func (sess *sessionWithContext) Driver() interface{} { if sess.sqlTx != nil { return sess.sqlTx } @@ -879,7 +896,7 @@ func (sess *session) Driver() interface{} { } // compileStatement compiles the given statement into a string. -func (sess *session) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}, error) { +func (sess *sessionWithContext) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}, error) { for i := range args { args[i] = sess.ConvertValue(args[i]) } @@ -897,7 +914,7 @@ func (sess *session) compileStatement(stmt *exql.Statement, args []interface{}) // prepareStatement compiles a query and tries to use previously generated // statement. -func (sess *session) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { +func (sess *sessionWithContext) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { sess.sqlDBMu.Lock() defer sess.sqlDBMu.Unlock() @@ -947,7 +964,7 @@ var waitForConnMu sync.Mutex // connectFn returns an error, then WaitForConnection will keep trying until // connectFn returns nil. Maximum waiting time is 5s after having acquired the // lock. -func (sess *session) WaitForConnection(connectFn func() error) error { +func (sess *sessionWithContext) WaitForConnection(connectFn func() error) error { // This lock ensures first-come, first-served and prevents opening too many // file descriptors. waitForConnMu.Lock() @@ -1012,6 +1029,7 @@ func ReplaceWithDollarSign(in string) string { func copySettings(from Session, into Session) { into.SetPreparedStatementCache(from.PreparedStatementCacheEnabled()) into.SetConnMaxLifetime(from.ConnMaxLifetime()) + into.SetConnMaxIdleTime(from.ConnMaxIdleTime()) into.SetMaxIdleConns(from.MaxIdleConns()) into.SetMaxOpenConns(from.MaxOpenConns()) } @@ -1032,8 +1050,6 @@ func newBaseTxID() uint64 { return atomic.AddUint64(&lastTxID, 1) } -var _ db.Session = &session{} - // TxContext creates a transaction context and runs fn within it. func TxContext(ctx context.Context, sess db.Session, fn func(tx db.Session) error, opts *sql.TxOptions) error { txFn := func(sess db.Session) error { @@ -1056,7 +1072,7 @@ func TxContext(ctx context.Context, sess db.Session, fn func(tx db.Session) erro var txErr error for i := 0; i < sess.MaxTransactionRetries(); i++ { - txErr = sess.(*session).Err(txFn(sess)) + txErr = sess.(*sessionWithContext).Err(txFn(sess)) if txErr == nil { return nil } @@ -1075,3 +1091,5 @@ func TxContext(ctx context.Context, sess db.Session, fn func(tx db.Session) erro return fmt.Errorf("db: giving up trying to commit transaction: %w", txErr) } + +var _ = db.Session(&sessionWithContext{}) diff --git a/internal/sqladapter/sqladapter_test.go b/internal/sqladapter/sqladapter_test.go index 773c76ad..1e1ab5be 100644 --- a/internal/sqladapter/sqladapter_test.go +++ b/internal/sqladapter/sqladapter_test.go @@ -7,12 +7,10 @@ import ( "github.com/upper/db/v4" ) -func TestInterfaces(t *testing.T) { - var ( - _ db.Collection = &collection{} - _ Collection = &collection{} - ) -} +var ( + _ db.Collection = &collectionWithSession{} + _ Collection = &collectionWithSession{} +) func TestReplaceWithDollarSign(t *testing.T) { tests := []struct { diff --git a/internal/sqlbuilder/builder.go b/internal/sqlbuilder/builder.go index da7ecb3c..b912e6cc 100644 --- a/internal/sqlbuilder/builder.go +++ b/internal/sqlbuilder/builder.go @@ -27,7 +27,6 @@ import ( "database/sql" "errors" "fmt" - "log" "reflect" "sort" "strconv" @@ -51,7 +50,11 @@ var defaultMapOptions = MapOptions{ IncludeNil: false, } -type compilable interface { +type hasPaginator interface { + Paginator() (db.Paginator, error) +} + +type isCompilable interface { Compile() (string, error) Arguments() []interface{} } @@ -347,7 +350,17 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err for i := range columns { switch v := columns[i].(type) { - case compilable: + case hasPaginator: + p, err := v.Paginator() + if err != nil { + return nil, nil, err + } + + q, a := Preprocess(p.String(), p.Arguments()) + + f[i] = exql.RawValue("(" + q + ")") + args = append(args, a...) + case isCompilable: c, err := v.Compile() if err != nil { return nil, nil, err @@ -560,7 +573,6 @@ type exprProxy struct { } func (p *exprProxy) Context() context.Context { - log.Printf("Missing context") return context.Background() } diff --git a/internal/sqlbuilder/convert.go b/internal/sqlbuilder/convert.go index 34d0c59a..37901f61 100644 --- a/internal/sqlbuilder/convert.go +++ b/internal/sqlbuilder/convert.go @@ -57,7 +57,7 @@ func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) // Byte slice gets transformed into a string. if v.Type().Elem().Kind() == reflect.Uint8 { - return []interface{}{string(value.([]byte))}, false + return []interface{}{string(v.Bytes())}, false } total = v.Len() @@ -122,7 +122,13 @@ func preprocessFn(arg interface{}) (string, []interface{}) { switch t := arg.(type) { case *adapter.RawExpr: return Preprocess(t.Raw(), t.Arguments()) - case compilable: + case hasPaginator: + p, err := t.Paginator() + if err == nil { + return `(` + p.String() + `)`, p.Arguments() + } + panic(err.Error()) + case isCompilable: c, err := t.Compile() if err == nil { return `(` + c + `)`, t.Arguments() diff --git a/internal/testsuite/sql_suite.go b/internal/testsuite/sql_suite.go index 407e484b..073ea8d7 100644 --- a/internal/testsuite/sql_suite.go +++ b/internal/testsuite/sql_suite.go @@ -1889,12 +1889,89 @@ func (s *SQLTestSuite) TestCustomType() { } func (s *SQLTestSuite) Test_Issue565() { - ctx, _ := context.WithTimeout(context.Background(), time.Nanosecond) - sess := s.Session().WithContext(ctx) + s.Session().Collection("birthdays").Insert(&birthday{ + Name: "Lucy", + Born: time.Now(), + }) - var result birthday - err := sess.Collection("birthdays").Find().Select("name").One(&result) + parentCtx := context.WithValue(s.Session().Context(), "carry", 1) + s.NotZero(parentCtx.Value("carry")) + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Nanosecond) + defer cancel() + + sess := s.Session() + + sess = sess.WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.Error(err) + s.Zero(result.Name) + + s.NotZero(ctx.Value("carry")) + } + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Second*10) + cancel() // cancel before passing + + sess := s.Session().WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.Error(err) + s.Zero(result.Name) + + s.NotZero(ctx.Value("carry")) + } + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Second) + defer cancel() + + sess := s.Session().WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.NoError(err) + s.NotZero(result.Name) + + s.NotZero(ctx.Value("carry")) + } +} + +func (s *SQLTestSuite) TestSelectFromSubquery() { + sess := s.Session() + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.SQL().SelectFrom("artist").Where(db.Cond{ + "name": db.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.Collection("artist").Find(db.Cond{ + "name": db.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } - s.Error(err) - s.Zero(result.Name) } diff --git a/session.go b/session.go index 507b86f8..80b8105c 100644 --- a/session.go +++ b/session.go @@ -90,10 +90,9 @@ type Session interface { // context.Background() is returned. Context() context.Context - // WithContext returns a copy of the session that uses the given context as - // default. Copies are safe to use concurrently but they're backed by the - // same Session. You may close a copy at any point but that won't close the - // parent session. + // WithContext returns the same session on a different default context. The + // session is identical to the original one in all ways except for the + // context. WithContext(ctx context.Context) Session Settings diff --git a/settings.go b/settings.go index 98811f09..d9b4177f 100644 --- a/settings.go +++ b/settings.go @@ -45,6 +45,14 @@ type Settings interface { // may be reused. ConnMaxLifetime() time.Duration + // SetConnMaxIdleTime sets the default maximum amount of time a connection + // may remain idle. + SetConnMaxIdleTime(time.Duration) + + // ConnMaxIdleTime returns the default maximum amount of time a connection + // may remain idle. + ConnMaxIdleTime() time.Duration + // SetMaxIdleConns sets the default maximum number of connections in the idle // connection pool. SetMaxIdleConns(int) @@ -76,6 +84,7 @@ type settings struct { preparedStatementCacheEnabled uint32 connMaxLifetime time.Duration + connMaxIdleTime time.Duration maxOpenConns int maxIdleConns int @@ -114,6 +123,18 @@ func (c *settings) ConnMaxLifetime() time.Duration { return c.connMaxLifetime } +func (c *settings) SetConnMaxIdleTime(t time.Duration) { + c.Lock() + c.connMaxIdleTime = t + c.Unlock() +} + +func (c *settings) ConnMaxIdleTime() time.Duration { + c.RLock() + defer c.RUnlock() + return c.connMaxIdleTime +} + func (c *settings) SetMaxIdleConns(n int) { c.Lock() c.maxIdleConns = n @@ -160,6 +181,7 @@ func NewSettings() Settings { return &settings{ preparedStatementCacheEnabled: def.preparedStatementCacheEnabled, connMaxLifetime: def.connMaxLifetime, + connMaxIdleTime: def.connMaxIdleTime, maxIdleConns: def.maxIdleConns, maxOpenConns: def.maxOpenConns, maxTransactionRetries: def.maxTransactionRetries, @@ -171,6 +193,7 @@ func NewSettings() Settings { var DefaultSettings Settings = &settings{ preparedStatementCacheEnabled: 0, connMaxLifetime: time.Duration(0), + connMaxIdleTime: time.Duration(0), maxIdleConns: 10, maxOpenConns: 0, maxTransactionRetries: 1,