Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set SNI for TSL connections #1088

Merged
merged 1 commit into from Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion conn.go
Expand Up @@ -1127,7 +1127,7 @@ func isDriverSetting(key string) bool {
return true
case "password":
return true
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline":
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni":
return true
case "fallback_application_name":
return true
Expand Down Expand Up @@ -2020,6 +2020,8 @@ func parseEnviron(env []string) (out map[string]string) {
accrue("sslkey")
case "PGSSLROOTCERT":
accrue("sslrootcert")
case "PGSSLSNI":
accrue("sslsni")
case "PGREQUIRESSL", "PGSSLCRL":
unsupported()
case "PGREQUIREPEER":
Expand Down
11 changes: 11 additions & 0 deletions ssl.go
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"os/user"
"path/filepath"
"strings"
)

// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
Expand Down Expand Up @@ -50,6 +51,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
}

// Set Server Name Indication (SNI), if enabled by connection parameters.
// By default SNI is on, any value which is not starting with "1" disables
// SNI -- that is the same check vanilla libpq uses.
if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") {
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4
// or IPv6). This check is coded already crypto.tls.hostnameInSNI, so
// just always set ServerName here and let crypto/tls do the filtering.
tlsConf.ServerName = o["host"]
}

err := sslClientCertificates(&tlsConf, o)
if err != nil {
return nil, err
Expand Down
139 changes: 139 additions & 0 deletions ssl_test.go
Expand Up @@ -3,12 +3,19 @@ package pq
// This file contains SSL tests

import (
"bytes"
_ "crypto/sha256"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
)

func maybeSkipSSLTests(t *testing.T) {
Expand Down Expand Up @@ -280,3 +287,135 @@ func TestSSLClientCertificates(t *testing.T) {
}
}
}

// Check that clint sends SNI data when `sslsni` is not disabled
func TestSNISupport(t *testing.T) {
t.Parallel()
tests := []struct {
name string
conn_param string
hostname string
expected_sni string
}{
{
name: "SNI is set by default",
conn_param: "",
hostname: "localhost",
expected_sni: "localhost",
},
{
name: "SNI is passed when asked for",
conn_param: "sslsni=1",
hostname: "localhost",
expected_sni: "localhost",
},
{
name: "SNI is not passed when disabled",
conn_param: "sslsni=0",
hostname: "localhost",
expected_sni: "",
},
{
name: "SNI is not set for IPv4",
conn_param: "",
hostname: "127.0.0.1",
expected_sni: "",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Start mock postgres server on OS-provided port
listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Fatal(err)
}
serverErrChan := make(chan error, 1)
serverSNINameChan := make(chan string, 1)
go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)

defer listener.Close()
defer close(serverErrChan)
defer close(serverSNINameChan)

// Try to establish a connection with the mock server. Connection will error out after TLS
// clientHello, but it is enough to catch SNI data on the server side
port := strings.Split(listener.Addr().String(), ":")[1]
connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param)

// We are okay to skip this error as we are polling serverErrChan and we'll get an error
// or timeout from the server side in case of problems here.
db, _ := sql.Open("postgres", connStr)
_, _ = db.Exec("SELECT 1")

// Check SNI data
select {
case sniHost := <-serverSNINameChan:
if sniHost != tt.expected_sni {
t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost)
}
case err = <-serverErrChan:
t.Fatalf("mock server failed with error: %+v", err)
case <-time.After(time.Second):
t.Fatal("exceeded connection timeout without erroring out")
}
})
}
}

// Make a postgres mock server to test TLS SNI
//
// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
// While reading clientHello catch passed SNI data and report it to nameChan.
func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
var sniHost string

conn, err := listener.Accept()
if err != nil {
errChan <- err
return
}
defer conn.Close()

err = conn.SetDeadline(time.Now().Add(time.Second))
if err != nil {
errChan <- err
return
}

// Receive StartupMessage with SSL Request
startupMessage := make([]byte, 8)
if _, err := io.ReadFull(conn, startupMessage); err != nil {
errChan <- err
return
}
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
return
}

// Respond with SSLOk
_, err = conn.Write([]byte("S"))
if err != nil {
errChan <- err
return
}

// Set up TLS context to catch clientHello. It will always error out during handshake
// as no certificate is set.
srv := tls.Server(conn, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
sniHost = argHello.ServerName
return nil, nil
},
})
defer srv.Close()

// Do the TLS handshake ignoring errors
_ = srv.Handshake()

nameChan <- sniHost
}