Skip to content

Commit

Permalink
Merge pull request #35 from upfluence/am/postgres-update
Browse files Browse the repository at this point in the history
backend/postgres: Various improvments
  • Loading branch information
AlexisMontagne committed Oct 20, 2023
2 parents 9770b35 + 6029c76 commit f059df7
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
go: [ '1.18.x', '1.17.x', '1.16.x' ]
go: [ '1.21.x', '1.20.x', '1.19.x' ]
services:
postgres:
image: postgres
Expand Down
42 changes: 32 additions & 10 deletions backend/postgres/config.go
Expand Up @@ -37,9 +37,12 @@ type Config struct {
Password string

SSLMode SSLMode
SSLSNI bool

CACertFile string
CACert *x509.Certificate
// Deprecated: Prefer passing a slice of CACerts
CACert *x509.Certificate
CACerts []*x509.Certificate

Role DBRole

Expand Down Expand Up @@ -78,29 +81,41 @@ func (c *Config) userInfo() *url.Userinfo {
return url.UserPassword(c.User, c.Password)
}

func writeCert(cert *x509.Certificate) (string, error) {
func writeBundle(certs []*x509.Certificate) (string, error) {
f, err := ioutil.TempFile("", "")

if err != nil {
return "", err
}

if err := pem.Encode(
f,
&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw},
); err != nil {
f.Close()
return "", err
for _, cert := range certs {
if err := pem.Encode(
f,
&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw},
); err != nil {
f.Close()
return "", err
}
}

return f.Name(), f.Close()
}

func (c *Config) caCerts() []*x509.Certificate {
var certs = c.CACerts

if c.CACert != nil {
certs = append(certs, c.CACert)
}

return certs
}

func (c *Config) sslValues() (url.Values, error) {
mode := Disable

if c.CACertFile == "" && c.CACert != nil {
c.certOnce.Do(func() { c.CACertFile, c.certErr = writeCert(c.CACert) })
c.certOnce.Do(func() { c.CACertFile, c.certErr = writeBundle(c.caCerts()) })

if c.certErr != nil {
return nil, c.certErr
Expand All @@ -113,7 +128,14 @@ func (c *Config) sslValues() (url.Values, error) {
mode = VerifyCA
}

vs := url.Values{"sslmode": {string(mode)}}
vs := url.Values{
"sslmode": {string(mode)},
"sslsni": {"0"},
}

if c.SSLSNI {
vs["sslsni"][0] = "1"
}

if c.CACertFile != "" {
vs["sslrootcert"] = []string{c.CACertFile}
Expand Down
10 changes: 5 additions & 5 deletions backend/postgres/config_test.go
Expand Up @@ -8,16 +8,16 @@ import (

func TestDSN(t *testing.T) {
for _, tt := range []struct {
c Config
c *Config
dsn string
}{
{
c: Config{DBName: "foobar", SSLMode: VerifyFull},
dsn: "postgres://localhost:5432/foobar?sslmode=verify-full",
c: &Config{DBName: "foobar", SSLMode: VerifyFull, SSLSNI: true},
dsn: "postgres://localhost:5432/foobar?sslmode=verify-full&sslsni=1",
},
{
c: Config{DBName: "foobar", CACertFile: "foobar"},
dsn: "postgres://localhost:5432/foobar?sslmode=verify-ca&sslrootcert=foobar",
c: &Config{DBName: "foobar", CACertFile: "foobar"},
dsn: "postgres://localhost:5432/foobar?sslmode=verify-ca&sslrootcert=foobar&sslsni=0",
},
} {
dsn, err := tt.c.DSN()
Expand Down
27 changes: 23 additions & 4 deletions backend/postgres/db.go
Expand Up @@ -96,6 +96,7 @@ func (q *queryer) Exec(ctx context.Context, stmt string, vs ...interface{}) (sql
}

res, err := q.q.Exec(ctx, stmt, vs...)

return res, wrapErr(err)
}

Expand Down Expand Up @@ -125,6 +126,20 @@ const (
rollbackClass = pq.ErrorClass("40")
)

type queryCanceledError struct {
cause *pq.Error
}

func (qce *queryCanceledError) Error() string {
return qce.cause.Error()
}

func (qce *queryCanceledError) Is(target error) bool {
return target == context.Canceled
}

func (qce *queryCanceledError) Unwrap() error { return qce.cause }

func wrapErr(err error) error {
if err == nil {
return err
Expand All @@ -136,6 +151,10 @@ func wrapErr(err error) error {
return err
}

if pqErr.Code == "57014" {
return &queryCanceledError{cause: pqErr}
}

switch pqErr.Code.Class() {
case constraintClass:
return wrapConstraintErr(pqErr)
Expand All @@ -149,7 +168,7 @@ func wrapErr(err error) error {
func wrapRollbackError(pqErr *pq.Error) error {
var err = sql.RollbackError{Cause: pqErr}

if pqErr.Code == pq.ErrorCode("40001") {
if pqErr.Code == "40001" {
err.Type = sql.SerializationFailure
}

Expand All @@ -160,11 +179,11 @@ func wrapConstraintErr(pqErr *pq.Error) error {
var err = sql.ConstraintError{Cause: pqErr, Constraint: pqErr.Column}

switch pqErr.Code {
case pq.ErrorCode("23503"):
case "23503":
err.Type = sql.ForeignKey
case pq.ErrorCode("23502"):
case "23502":
err.Type = sql.NotNull
case pq.ErrorCode("23505"):
case "23505":
if strings.HasSuffix(pqErr.Constraint, "_pkey") {
err.Type = sql.PrimaryKey
} else {
Expand Down
6 changes: 3 additions & 3 deletions go.mod
@@ -1,17 +1,17 @@
module github.com/upfluence/sql

go 1.18
go 1.21

require (
github.com/lib/pq v1.4.0
github.com/mattn/go-sqlite3 v1.13.1-0.20200416054559-98a44bcf5949
github.com/stretchr/testify v1.6.1
github.com/stretchr/testify v1.8.4
github.com/upfluence/errors v0.2.2
github.com/upfluence/log v0.0.4
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
6 changes: 4 additions & 2 deletions go.sum
Expand Up @@ -123,8 +123,9 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
Expand Down Expand Up @@ -196,5 +197,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20191120175047-4206685974f2/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo=
gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
36 changes: 32 additions & 4 deletions integration/error_test.go
Expand Up @@ -3,15 +3,46 @@ package integration
import (
"context"
"testing"
"time"

"github.com/lib/pq"
"github.com/stretchr/testify/assert"

"github.com/upfluence/sql"
"github.com/upfluence/sql/sqltest"
"github.com/upfluence/sql/x/migration"
)

func TestCanceling(t *testing.T) {
sqltest.NewTestCase().Run(t, func(t *testing.T, db sql.DB) {
ctx, done := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer done()

var stmt string

switch d := db.Driver(); d {
case "postgres":
stmt = "pg_sleep(1)"
default:
t.Skipf("driver not handled: %q", d)
}

err := db.QueryRow(ctx, "SELECT "+stmt).Scan()

assert.ErrorIs(t, err, context.Canceled)
})
}

func TestCanceled(t *testing.T) {
sqltest.NewTestCase().Run(t, func(t *testing.T, db sql.DB) {
ctx, done := context.WithCancel(context.Background())
done()

err := db.QueryRow(ctx, "SELECT 1").Scan()

assert.ErrorIs(t, err, context.Canceled)
})
}

func TestConstraintPrimaryKeyError(t *testing.T) {
sqltest.NewTestCase(
sqltest.WithMigratorFunc(func(db sql.DB) migration.Migrator {
Expand All @@ -35,9 +66,6 @@ func TestConstraintPrimaryKeyError(t *testing.T) {

assert.True(t, ok)

if pqerr, ok := cerr.Cause.(*pq.Error); ok {
t.Logf("%+v", pqerr.Constraint)
}
assert.Equal(t, sql.PrimaryKey, cerr.Type)
assert.Equal(
t,
Expand Down

0 comments on commit f059df7

Please sign in to comment.