Skip to content

Commit

Permalink
Add keepalives support
Browse files Browse the repository at this point in the history
  • Loading branch information
marselester committed Sep 30, 2020
1 parent 083382b commit 9b159b5
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 2 deletions.
2 changes: 2 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,8 @@ func isDriverSetting(key string) bool {
return true
case "fallback_application_name":
return true
case "keepalives", "keepalives_interval":
return true
case "connect_timeout":
return true
case "disable_prepared_binary_result":
Expand Down
39 changes: 37 additions & 2 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
)

// Connector represents a fixed configuration for the pq driver with a given
Expand Down Expand Up @@ -107,9 +110,41 @@ func NewConnector(dsn string) (*Connector, error) {
}

// SSL is not necessary or supported over UNIX domain sockets
if network, _ := network(o); network == "unix" {
ntw, _ := network(o)
if ntw == "unix" {
o["sslmode"] = "disable"
}

return &Connector{opts: o, dialer: defaultDialer{}}, nil
var d net.Dialer
if ntw == "tcp" {
d.KeepAlive, err = keepalive(o)
if err != nil {
return nil, err
}
}

return &Connector{opts: o, dialer: defaultDialer{d}}, nil
}

// keepalive returns the interval between keep-alive probes controlled by keepalives_interval.
// If zero, keep-alive probes are sent with a default value (see net.Dialer).
// If negative, keep-alive probes are disabled.
//
// The keepalives parameter controls whether client-side TCP keepalives are used.
// The default value is 1, meaning on, but you can change this to 0, meaning off, if keepalives are not wanted.
func keepalive(o values) (time.Duration, error) {
v, ok := o["keepalives"]
if v == "0" {
return -1, nil
}

if v, ok = o["keepalives_interval"]; !ok {
return 0, nil
}

keepintvl, err := strconv.ParseInt(v, 10, 0)
if err != nil {
return 0, fmt.Errorf("invalid value for parameter keepalives_interval: %w", err)
}
return time.Duration(keepintvl) * time.Second, nil
}
70 changes: 70 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"strconv"
"testing"
"time"
)

func TestNewConnector_WorksWithOpenDB(t *testing.T) {
Expand Down Expand Up @@ -65,3 +68,70 @@ func TestNewConnector_Driver(t *testing.T) {
}
txn.Rollback()
}

func TestNewConnectorKeepalive(t *testing.T) {
c, err := NewConnector("keepalives=1 keepalives_interval=10")
if err != nil {
t.Fatal(err)
}
db := sql.OpenDB(c)
defer db.Close()
// database/sql might not call our Open at all unless we do something with
// the connection
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
txn.Rollback()

d, _ := c.dialer.(defaultDialer)
want := 10 * time.Second
if want != d.d.KeepAlive {
t.Fatalf("expected: %v, got: %v", want, d.d.KeepAlive)
}
}

func TestKeepalive(t *testing.T) {
var tt = map[string]struct {
input values
want time.Duration
}{
"keepalives on": {values{"keepalives": "1"}, 0},
"keepalives on by default": {nil, 0},
"keepalives off": {values{"keepalives": "0"}, -1},
"keepalives_interval 5 seconds": {values{"keepalives_interval": "5"}, 5 * time.Second},
"keepalives_interval default": {values{"keepalives_interval": "0"}, 0},
"keepalives_interval off": {values{"keepalives_interval": "-1"}, -1 * time.Second},
}

for name, tc := range tt {
t.Run(name, func(t *testing.T) {
got, err := keepalive(tc.input)
if err != nil {
t.Fatal(err)
}
if tc.want != got {
t.Fatalf("expected: %v, got: %v", tc.want, got)
}
})
}
}

func TestKeepaliveError(t *testing.T) {
var tt = map[string]struct {
input values
want error
}{
"keepalives_interval whitespace": {values{"keepalives_interval": " "}, strconv.ErrSyntax},
"keepalives_interval float": {values{"keepalives_interval": "1.1"}, strconv.ErrSyntax},
}

for name, tc := range tt {
t.Run(name, func(t *testing.T) {
_, err := keepalive(tc.input)
if !errors.Is(err, tc.want) {
t.Fatalf("expected: %v, got: %v", tc.want, err)
}
})
}
}
6 changes: 6 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ supported:
* sslmode - Whether or not to use SSL (default is require, this is not
the default for libpq)
* fallback_application_name - An application_name to fall back to if one isn't provided.
* keepalives - Whether or not to use client-side TCP keepalives
(the default value is 1, meaning on, but you can change this to 0, meaning off)
* keepalives_interval - The number of seconds after which a TCP keepalive message
that is not acknowledged by the server should be retransmitted.
If zero or not specified, keep-alive probes are sent with a default value (see net.Dialer).
If negative, keep-alive probes are disabled.
* connect_timeout - Maximum wait for connection, in seconds. Zero or
not specified means wait indefinitely.
* sslcert - Cert file location. The file must contain PEM encoded data.
Expand Down

0 comments on commit 9b159b5

Please sign in to comment.