Skip to content

Commit

Permalink
Merge pull request #2 from nicheinc/feature/pgx-with-connection
Browse files Browse the repository at this point in the history
Add WithConnection for pgx and pgx5
  • Loading branch information
bclarkx2 committed Sep 29, 2023
2 parents 856ea12 + 6242e1a commit e3a73d1
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 24 deletions.
41 changes: 29 additions & 12 deletions database/pgx/pgx.go
Expand Up @@ -67,19 +67,19 @@ type Postgres struct {
config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := conn.PingContext(ctx); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -93,7 +93,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand Down Expand Up @@ -121,15 +121,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
}
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Postgres{
conn: conn,
db: instance,
config: config,
}

Expand All @@ -140,6 +133,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return px, nil
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()

if err := instance.Ping(); err != nil {
return nil, err
}

conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}

px, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}
px.db = instance
return px, nil
}

func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
Expand Down Expand Up @@ -214,7 +227,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) {

func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
var dbErr error
if p.db != nil {
dbErr = p.db.Close()
}

if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down
38 changes: 38 additions & 0 deletions database/pgx/pgx_test.go
Expand Up @@ -682,6 +682,44 @@ func TestWithInstance_Concurrent(t *testing.T) {
}
})
}

func TestWithConnection(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()

ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}

p, err := WithConnection(ctx, conn, &Config{})
if err != nil {
t.Fatal(err)
}

defer func() {
if err := p.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, p, []byte("SELECT 1"))
})
}

func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
Expand Down
41 changes: 29 additions & 12 deletions database/pgx/v5/pgx.go
Expand Up @@ -65,19 +65,19 @@ type Postgres struct {
config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := conn.PingContext(ctx); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -91,7 +91,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand Down Expand Up @@ -119,15 +119,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
}
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Postgres{
conn: conn,
db: instance,
config: config,
}

Expand All @@ -138,6 +131,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return px, nil
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()

if err := instance.Ping(); err != nil {
return nil, err
}

conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}

px, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}
px.db = instance
return px, nil
}

func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
Expand Down Expand Up @@ -212,7 +225,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) {

func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
var dbErr error
if p.db != nil {
dbErr = p.db.Close()
}

if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down
38 changes: 38 additions & 0 deletions database/pgx/v5/pgx_test.go
Expand Up @@ -683,6 +683,44 @@ func TestWithInstance_Concurrent(t *testing.T) {
}
})
}

func TestWithConnection(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()

ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}

p, err := WithConnection(ctx, conn, &Config{})
if err != nil {
t.Fatal(err)
}

defer func() {
if err := p.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, p, []byte("SELECT 1"))
})
}

func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
Expand Down

0 comments on commit e3a73d1

Please sign in to comment.