Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keepalives support #999

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,11 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
// Most users should only use it through database/sql package from the standard
// library.
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
return c.open(context.Background())
}

// DialOpen opens a new connection to the database using a dialer.
Expand Down Expand Up @@ -1068,6 +1072,8 @@ func isDriverSetting(key string) bool {
return true
case "fallback_application_name":
return true
case "keepalives", "keepalives_interval":
return true
case "connect_timeout":
return true
case "disable_prepared_binary_result":
Expand Down
40 changes: 40 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,46 @@ func TestOpenURL(t *testing.T) {
testURL("postgresql://")
}

func TestOpen(t *testing.T) {
dsn := "keepalives_interval=10"
c, err := Open(dsn)
if err != nil {
t.Fatal(err)
}
defer c.Close()

d := c.(*conn).dialer.(defaultDialer)
want := 10 * time.Second
if want != d.d.KeepAlive {
t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive)
}
}

func TestSQLOpen(t *testing.T) {
dsn := "keepalives_interval=10"
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
}
defer db.Close()
if err = db.Ping(); err != nil {
t.Fatal(err)
}

drv := db.Driver()
c, err := drv.Open(dsn)
if err != nil {
t.Fatal(err)
}
defer c.Close()

d := c.(*conn).dialer.(defaultDialer)
want := 10 * time.Second
if want != d.d.KeepAlive {
t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive)
}
}

const pgpassFile = "/tmp/pqgotest_pgpass"

func TestPgpass(t *testing.T) {
Expand Down
39 changes: 37 additions & 2 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
)

// Connector represents a fixed configuration for the pq driver with a given
Expand Down Expand Up @@ -107,9 +110,41 @@ func NewConnector(dsn string) (*Connector, error) {
}

// SSL is not necessary or supported over UNIX domain sockets
if network, _ := network(o); network == "unix" {
ntw, _ := network(o)
if ntw == "unix" {
o["sslmode"] = "disable"
}

return &Connector{opts: o, dialer: defaultDialer{}}, nil
var d net.Dialer
if ntw == "tcp" {
d.KeepAlive, err = keepalive(o)
if err != nil {
return nil, err
}
}

return &Connector{opts: o, dialer: defaultDialer{d}}, nil
}

// keepalive returns the interval between keep-alive probes controlled by keepalives_interval.
// If zero, keep-alive probes are sent with a default value (see net.Dialer).
// If negative, keep-alive probes are disabled.
//
// The keepalives parameter controls whether client-side TCP keepalives are used.
// The default value is 1, meaning on, but you can change this to 0, meaning off, if keepalives are not wanted.
func keepalive(o values) (time.Duration, error) {
v, ok := o["keepalives"]
if ok && v == "0" {
return -1, nil
}

if v, ok = o["keepalives_interval"]; !ok {
return 0, nil
}

keepintvl, err := strconv.ParseInt(v, 10, 0)
if err != nil {
return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %v", err)
}
return time.Duration(keepintvl) * time.Second, nil
}
70 changes: 70 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"context"
"database/sql"
"database/sql/driver"
"strconv"
"strings"
"testing"
"time"
)

func TestNewConnector_WorksWithOpenDB(t *testing.T) {
Expand Down Expand Up @@ -65,3 +68,70 @@ func TestNewConnector_Driver(t *testing.T) {
}
txn.Rollback()
}

func TestNewConnectorKeepalive(t *testing.T) {
c, err := NewConnector("keepalives=1 keepalives_interval=10")
if err != nil {
t.Fatal(err)
}
db := sql.OpenDB(c)
defer db.Close()
// database/sql might not call our Open at all unless we do something with
// the connection
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
txn.Rollback()

d, _ := c.dialer.(defaultDialer)
want := 10 * time.Second
if want != d.d.KeepAlive {
t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive)
}
}

func TestKeepalive(t *testing.T) {
var tt = map[string]struct {
input values
want time.Duration
}{
"keepalives on": {values{"keepalives": "1"}, 0},
"keepalives on by default": {nil, 0},
"keepalives off": {values{"keepalives": "0"}, -1},
"keepalives_interval 5 seconds": {values{"keepalives_interval": "5"}, 5 * time.Second},
"keepalives_interval default": {values{"keepalives_interval": "0"}, 0},
"keepalives_interval off": {values{"keepalives_interval": "-1"}, -1 * time.Second},
}

for name, tc := range tt {
t.Run(name, func(t *testing.T) {
got, err := keepalive(tc.input)
if err != nil {
t.Fatal(err)
}
if tc.want != got {
t.Fatalf("expected: %v, got: %v", tc.want, got)
}
})
}
}

func TestKeepaliveError(t *testing.T) {
var tt = map[string]struct {
input values
want error
}{
"keepalives_interval whitespace": {values{"keepalives_interval": " "}, strconv.ErrSyntax},
"keepalives_interval float": {values{"keepalives_interval": "1.1"}, strconv.ErrSyntax},
}

for name, tc := range tt {
t.Run(name, func(t *testing.T) {
_, err := keepalive(tc.input)
if !strings.HasSuffix(err.Error(), tc.want.Error()) {
t.Fatalf("expected: %v, got: %v", tc.want, err)
}
})
}
}
6 changes: 6 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ supported:
* sslmode - Whether or not to use SSL (default is require, this is not
the default for libpq)
* fallback_application_name - An application_name to fall back to if one isn't provided.
* keepalives - Whether or not to use client-side TCP keepalives
(the default value is 1, meaning on, but you can change this to 0, meaning off)
* keepalives_interval - The number of seconds after which a TCP keepalive message
that is not acknowledged by the server should be retransmitted.
If zero or not specified, keep-alive probes are sent with a default value (see net.Dialer).
If negative, keep-alive probes are disabled.
* connect_timeout - Maximum wait for connection, in seconds. Zero or
not specified means wait indefinitely.
* sslcert - Cert file location. The file must contain PEM encoded data.
Expand Down