Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new "tls-mode=preferred" DSN parameter #928

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -25,6 +25,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
David Weitzman <dweitzman at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Evan Shaw <evan at vendhq.com>
Expand Down
11 changes: 11 additions & 0 deletions README.md
Expand Up @@ -335,6 +335,17 @@ Default: false
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).


##### `tls-mode`

```
Type: string
Valid Values: preferred
Default: <none>
```

Use `tls-mode=preferred` to opt-in to TLS / SSL only with servers that support it. `preferred` does not authenticate the server but allows servers to optionally authenticate clients. The [`tls`](#tls) DSN parameter allows customizing the TLS config.


##### `writeTimeout`

```
Expand Down
20 changes: 20 additions & 0 deletions dsn.go
Expand Up @@ -46,6 +46,7 @@ type Config struct {
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
TLSOptional bool // Allows non-TLS for servers that don't have TLS capability
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Expand Down Expand Up @@ -287,6 +288,10 @@ func (cfg *Config) FormatDSN() string {
buf.WriteString(url.QueryEscape(cfg.TLSConfig))
}

if cfg.TLSOptional && cfg.TLSConfig != "preferred" {
buf.WriteString("&tls-mode=preferred")
}

if cfg.WriteTimeout > 0 {
if hasParam {
buf.WriteString("&writeTimeout=")
Expand Down Expand Up @@ -550,6 +555,18 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return
}

// TLS enforcement settings
case "tls-mode":
switch value {
case "preferred":
cfg.TLSOptional = true
if cfg.tls == nil {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
}
default:
return errors.New("invalid value / unknown tls-mode: " + value)
}

// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
Expand All @@ -563,6 +580,9 @@ func parseDSNParams(cfg *Config, params string) (err error) {
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
if vl == "preferred" {
cfg.TLSOptional = true
}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
Expand Down
31 changes: 31 additions & 0 deletions dsn_test.go
Expand Up @@ -35,6 +35,12 @@ var testDSNs = []struct {
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=preferred",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "preferred", TLSOptional: true},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true&tls-mode=preferred",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true", TLSOptional: true},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
Expand Down Expand Up @@ -246,6 +252,31 @@ func TestDSNTLSConfig(t *testing.T) {
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
}

dsn = "tcp(example.com)/?tls-mode=invalid"
_, err = ParseDSN(dsn)
wantError := "invalid value / unknown tls-mode: invalid"
if err == nil || err.Error() != wantError {
t.Errorf("ParseDSN(%s). Got error: %v. Want error: %v", dsn, err, wantError)
}

dsn = "tcp(example.com)/?tls-mode=preferred"
cfg, err = ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
}
if !cfg.tls.InsecureSkipVerify {
t.Errorf("cfg.tls.InsecureSkipVerify should be true")
}
if !cfg.TLSOptional {
t.Error("cfg.TLSOptional should be true")
}
}

func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Expand Up @@ -194,7 +194,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
if mc.cfg.TLSConfig == "preferred" {
if mc.cfg.TLSOptional {
mc.cfg.tls = nil
} else {
return nil, "", ErrNoTLS
Expand Down