diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index deaca94ea..f47206017 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -67,19 +67,19 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -93,7 +93,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -121,15 +121,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ conn: conn, - db: instance, config: config, } @@ -140,6 +133,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return px, nil } +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { + return nil, err + } + + conn, err := instance.Conn(ctx) + if err != nil { + return nil, err + } + + px, err := WithConnection(ctx, conn, config) + if err != nil { + return nil, err + } + px.db = instance + return px, nil +} + func (p *Postgres) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -214,7 +227,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) { func (p *Postgres) Close() error { connErr := p.conn.Close() - dbErr := p.db.Close() + var dbErr error + if p.db != nil { + dbErr = p.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 5d7a5238e..98a0477ed 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -682,6 +682,44 @@ func TestWithInstance_Concurrent(t *testing.T) { } }) } + +func TestWithConnection(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("pgx", pgConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + p, err := WithConnection(ctx, conn, &Config{}) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := p.Close(); err != nil { + t.Error(err) + } + }() + dt.Test(t, p, []byte("SELECT 1")) + }) +} + func Test_computeLineFromPos(t *testing.T) { testcases := []struct { pos int diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 1b5a6ea7a..3a225ff23 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -65,19 +65,19 @@ type Postgres struct { config *Config } -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } if config.DatabaseName == "" { query := `SELECT CURRENT_DATABASE()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -91,7 +91,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT CURRENT_SCHEMA()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -119,15 +119,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } } - conn, err := instance.Conn(context.Background()) - - if err != nil { - return nil, err - } - px := &Postgres{ conn: conn, - db: instance, config: config, } @@ -138,6 +131,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return px, nil } +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { + return nil, err + } + + conn, err := instance.Conn(ctx) + if err != nil { + return nil, err + } + + px, err := WithConnection(ctx, conn, config) + if err != nil { + return nil, err + } + px.db = instance + return px, nil +} + func (p *Postgres) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -212,7 +225,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) { func (p *Postgres) Close() error { connErr := p.conn.Close() - dbErr := p.db.Close() + var dbErr error + if p.db != nil { + dbErr = p.db.Close() + } + if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } diff --git a/database/pgx/v5/pgx_test.go b/database/pgx/v5/pgx_test.go index c7339c4fc..61c9260ec 100644 --- a/database/pgx/v5/pgx_test.go +++ b/database/pgx/v5/pgx_test.go @@ -683,6 +683,44 @@ func TestWithInstance_Concurrent(t *testing.T) { } }) } + +func TestWithConnection(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("pgx", pgConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + p, err := WithConnection(ctx, conn, &Config{}) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := p.Close(); err != nil { + t.Error(err) + } + }() + dt.Test(t, p, []byte("SELECT 1")) + }) +} + func Test_computeLineFromPos(t *testing.T) { testcases := []struct { pos int