diff --git a/conn.go b/conn.go index 7d83f672..e70b386f 100644 --- a/conn.go +++ b/conn.go @@ -1127,7 +1127,7 @@ func isDriverSetting(key string) bool { return true case "password": return true - case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline": + case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": return true case "fallback_application_name": return true @@ -2020,6 +2020,8 @@ func parseEnviron(env []string) (out map[string]string) { accrue("sslkey") case "PGSSLROOTCERT": accrue("sslrootcert") + case "PGSSLSNI": + accrue("sslsni") case "PGREQUIRESSL", "PGSSLCRL": unsupported() case "PGREQUIREPEER": diff --git a/ssl.go b/ssl.go index e5eb9289..36b61ba4 100644 --- a/ssl.go +++ b/ssl.go @@ -8,6 +8,7 @@ import ( "os" "os/user" "path/filepath" + "strings" ) // ssl generates a function to upgrade a net.Conn based on the "sslmode" and @@ -50,6 +51,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) } + // Set Server Name Indication (SNI), if enabled by connection parameters. + // By default SNI is on, any value which is not starting with "1" disables + // SNI -- that is the same check vanilla libpq uses. + if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { + // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 + // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so + // just always set ServerName here and let crypto/tls do the filtering. + tlsConf.ServerName = o["host"] + } + err := sslClientCertificates(&tlsConf, o) if err != nil { return nil, err diff --git a/ssl_test.go b/ssl_test.go index e00522e7..64d68cf4 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -3,12 +3,19 @@ package pq // This file contains SSL tests import ( + "bytes" _ "crypto/sha256" + "crypto/tls" "crypto/x509" "database/sql" + "fmt" + "io" + "net" "os" "path/filepath" + "strings" "testing" + "time" ) func maybeSkipSSLTests(t *testing.T) { @@ -280,3 +287,135 @@ func TestSSLClientCertificates(t *testing.T) { } } } + +// Check that clint sends SNI data when `sslsni` is not disabled +func TestSNISupport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + conn_param string + hostname string + expected_sni string + }{ + { + name: "SNI is set by default", + conn_param: "", + hostname: "localhost", + expected_sni: "localhost", + }, + { + name: "SNI is passed when asked for", + conn_param: "sslsni=1", + hostname: "localhost", + expected_sni: "localhost", + }, + { + name: "SNI is not passed when disabled", + conn_param: "sslsni=0", + hostname: "localhost", + expected_sni: "", + }, + { + name: "SNI is not set for IPv4", + conn_param: "", + hostname: "127.0.0.1", + expected_sni: "", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Start mock postgres server on OS-provided port + listener, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + serverErrChan := make(chan error, 1) + serverSNINameChan := make(chan string, 1) + go mockPostgresSSL(listener, serverErrChan, serverSNINameChan) + + defer listener.Close() + defer close(serverErrChan) + defer close(serverSNINameChan) + + // Try to establish a connection with the mock server. Connection will error out after TLS + // clientHello, but it is enough to catch SNI data on the server side + port := strings.Split(listener.Addr().String(), ":")[1] + connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param) + + // We are okay to skip this error as we are polling serverErrChan and we'll get an error + // or timeout from the server side in case of problems here. + db, _ := sql.Open("postgres", connStr) + _, _ = db.Exec("SELECT 1") + + // Check SNI data + select { + case sniHost := <-serverSNINameChan: + if sniHost != tt.expected_sni { + t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost) + } + case err = <-serverErrChan: + t.Fatalf("mock server failed with error: %+v", err) + case <-time.After(time.Second): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + +// Make a postgres mock server to test TLS SNI +// +// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection. +// While reading clientHello catch passed SNI data and report it to nameChan. +func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) { + var sniHost string + + conn, err := listener.Accept() + if err != nil { + errChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second)) + if err != nil { + errChan <- err + return + } + + // Receive StartupMessage with SSL Request + startupMessage := make([]byte, 8) + if _, err := io.ReadFull(conn, startupMessage); err != nil { + errChan <- err + return + } + // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber + if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) { + errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } + + // Respond with SSLOk + _, err = conn.Write([]byte("S")) + if err != nil { + errChan <- err + return + } + + // Set up TLS context to catch clientHello. It will always error out during handshake + // as no certificate is set. + srv := tls.Server(conn, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + sniHost = argHello.ServerName + return nil, nil + }, + }) + defer srv.Close() + + // Do the TLS handshake ignoring errors + _ = srv.Handshake() + + nameChan <- sniHost +}