diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 7e42d29c9..b746f0604 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -77,19 +77,19 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -103,7 +103,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -139,15 +139,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ conn: conn, - db: instance, config: config, } @@ -162,6 +155,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return px, nil } +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { + return nil, err + } + + conn, err := instance.Conn(ctx) + if err != nil { + return nil, err + } + + px, err := WithConnection(ctx, conn, config) + if err != nil { + return nil, err + } + px.db = instance + return px, nil +} + func (p *Postgres) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -241,7 +254,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) { func (p *Postgres) Close() error { connErr := p.conn.Close() - dbErr := p.db.Close() + var dbErr error + if p.db != nil { + dbErr = p.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } @@ -353,7 +370,7 @@ func (p *Postgres) releaseTableLock() error { } query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" - if _, err := p.db.Exec(query, aid); err != nil { + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} } @@ -598,7 +615,7 @@ func (p *Postgres) ensureLockTable() error { var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` - if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil { + if err := p.conn.QueryRowContext(context.Background(), query, p.config.LockTable).Scan(&count); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { @@ -606,7 +623,7 @@ func (p *Postgres) ensureLockTable() error { } query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)` - if _, err := p.db.Exec(query); err != nil { + if _, err := p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 53e8e1d86..d7146875c 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -708,6 +708,44 @@ func TestWithInstance_Concurrent(t *testing.T) { } }) } + +func TestWithConnection(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("pgx", pgConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + p, err := WithConnection(ctx, conn, &Config{}) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := p.Close(); err != nil { + t.Error(err) + } + }() + dt.Test(t, p, []byte("SELECT 1")) + }) +} + func Test_computeLineFromPos(t *testing.T) { testcases := []struct { pos int diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..3a225ff23 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -65,19 +65,19 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -91,7 +91,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -119,15 +119,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ conn: conn, - db: instance, config: config, } @@ -138,6 +131,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return px, nil } +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { + return nil, err + } + + conn, err := instance.Conn(ctx) + if err != nil { + return nil, err + } + + px, err := WithConnection(ctx, conn, config) + if err != nil { + return nil, err + } + px.db = instance + return px, nil +} + func (p *Postgres) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -212,7 +225,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) { func (p *Postgres) Close() error { connErr := p.conn.Close() - dbErr := p.db.Close() + var dbErr error + if p.db != nil { + dbErr = p.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index c7339c4fc..61c9260ec 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -683,6 +683,44 @@ func TestWithInstance_Concurrent(t *testing.T) { } }) } + +func TestWithConnection(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("pgx", pgConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + p, err := WithConnection(ctx, conn, &Config{}) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := p.Close(); err != nil { + t.Error(err) + } + }() + dt.Test(t, p, []byte("SELECT 1")) + }) +} + func Test_computeLineFromPos(t *testing.T) { testcases := []struct { pos int