From 6d36ff035963d77f1e3e30252c290cb5da563e56 Mon Sep 17 00:00:00 2001 From: David Weitzman Date: Thu, 7 Mar 2019 03:13:24 -0800 Subject: [PATCH] Add a new "tls-mode=preferred" DSN parameter Separating "preferred" into its own parameter instead of making it a special value in the "tls=" parameter makes it possible to use custom TLS config with this mode. This is useful when clients don't need to authenticate servers using TLS but a server may or may not need to authenticate the client using TLS. --- AUTHORS | 1 + README.md | 11 +++++++++++ dsn.go | 20 ++++++++++++++++++++ dsn_test.go | 31 +++++++++++++++++++++++++++++++ packets.go | 2 +- 5 files changed, 64 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 5482a8536..5cdc7391b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -25,6 +25,7 @@ Daniel Montoya Daniel Nichter Daniƫl van Eeden Dave Protasowski +David Weitzman DisposaBoy Egor Smolyakov Evan Shaw diff --git a/README.md b/README.md index 341d9194c..7d6cf627c 100644 --- a/README.md +++ b/README.md @@ -335,6 +335,17 @@ Default: false `tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). +##### `tls-mode` + +``` +Type: string +Valid Values: preferred +Default: +``` + +Use `tls-mode=preferred` to opt-in to TLS / SSL only with servers that support it. `preferred` does not authenticate the server but allows servers to optionally authenticate clients. The [`tls`](#tls) DSN parameter allows customizing the TLS config. + + ##### `writeTimeout` ``` diff --git a/dsn.go b/dsn.go index b9134722e..db487c872 100644 --- a/dsn.go +++ b/dsn.go @@ -46,6 +46,7 @@ type Config struct { pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name tls *tls.Config // TLS configuration + TLSOptional bool // Allows non-TLS for servers that don't have TLS capability Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout @@ -287,6 +288,10 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(url.QueryEscape(cfg.TLSConfig)) } + if cfg.TLSOptional && cfg.TLSConfig != "preferred" { + buf.WriteString("&tls-mode=preferred") + } + if cfg.WriteTimeout > 0 { if hasParam { buf.WriteString("&writeTimeout=") @@ -550,6 +555,18 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } + // TLS enforcement settings + case "tls-mode": + switch value { + case "preferred": + cfg.TLSOptional = true + if cfg.tls == nil { + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } + default: + return errors.New("invalid value / unknown tls-mode: " + value) + } + // TLS-Encryption case "tls": boolValue, isBool := readBool(value) @@ -563,6 +580,9 @@ func parseDSNParams(cfg *Config, params string) (err error) { } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl cfg.tls = &tls.Config{InsecureSkipVerify: true} + if vl == "preferred" { + cfg.TLSOptional = true + } } else { name, err := url.QueryUnescape(value) if err != nil { diff --git a/dsn_test.go b/dsn_test.go index 1cd095496..0bc94aafd 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -35,6 +35,12 @@ var testDSNs = []struct { }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=preferred", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "preferred", TLSOptional: true}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true&tls-mode=preferred", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true", TLSOptional: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"}, @@ -246,6 +252,31 @@ func TestDSNTLSConfig(t *testing.T) { if cfg.tls.ServerName != expectedServerName { t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) } + + dsn = "tcp(example.com)/?tls-mode=invalid" + _, err = ParseDSN(dsn) + wantError := "invalid value / unknown tls-mode: invalid" + if err == nil || err.Error() != wantError { + t.Errorf("ParseDSN(%s). Got error: %v. Want error: %v", dsn, err, wantError) + } + + dsn = "tcp(example.com)/?tls-mode=preferred" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } + if !cfg.tls.InsecureSkipVerify { + t.Errorf("cfg.tls.InsecureSkipVerify should be true") + } + if !cfg.TLSOptional { + t.Error("cfg.TLSOptional should be true") + } } func TestDSNWithCustomTLSQueryEscape(t *testing.T) { diff --git a/packets.go b/packets.go index 5e0853767..6396e7d2b 100644 --- a/packets.go +++ b/packets.go @@ -194,7 +194,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - if mc.cfg.TLSConfig == "preferred" { + if mc.cfg.TLSOptional { mc.cfg.tls = nil } else { return nil, "", ErrNoTLS