Skip to content

Commit

Permalink
Add AllowFallbackToPlaintext and TLS to config (#1370)
Browse files Browse the repository at this point in the history
  • Loading branch information
lance6716 committed Nov 28, 2022
1 parent fa1e4ed commit 41dd159
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 63 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -61,6 +61,7 @@ Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lance Tian <lance6716 at gmail.com>
Lennart Rudolph <lrudolph at hmc.edu>
Leonardo YongUk Kim <dalinaum at gmail.com>
Linh Tran Tuan <linhduonggnu at gmail.com>
Expand Down
11 changes: 11 additions & 0 deletions README.md
Expand Up @@ -157,6 +157,17 @@ Default: false

`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.


##### `allowFallbackToPlaintext`

```
Type: bool
Valid Values: true, false
Default: false
```

`allowFallbackToPlaintext=true` acts like a `--ssl-mode=PREFERRED` MySQL client as described in [Command Options for Connecting to the Server](https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode)

##### `allowNativePasswords`

```
Expand Down
4 changes: 2 additions & 2 deletions auth.go
Expand Up @@ -275,7 +275,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
}
// unlike caching_sha2_password, sha256_password does not accept
// cleartext password on unix transport.
if mc.cfg.tls != nil {
if mc.cfg.TLS != nil {
// write cleartext auth packet
return append([]byte(mc.cfg.Passwd), 0), nil
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
}

case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
if mc.cfg.TLS != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions auth_test.go
Expand Up @@ -291,7 +291,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {

// Hack to make the caching_sha2_password plugin believe that the connection
// is secure
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}

// check written auth response
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
Expand Down Expand Up @@ -663,7 +663,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {

// hack to make the caching_sha2_password plugin believe that the connection
// is secure
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}

authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
62, 94, 83, 80, 52, 85}
Expand All @@ -676,7 +676,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
}

// unset TLS config to prevent the actual establishment of a TLS wrapper
mc.cfg.tls = nil
mc.cfg.TLS = nil

err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
Expand Down Expand Up @@ -866,7 +866,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) {

// Hack to make the caching_sha2_password plugin believe that the connection
// is secure
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}

// auth switch request
conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
Expand Down Expand Up @@ -1299,7 +1299,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) {

// Hack to make the caching_sha2_password plugin believe that the connection
// is secure
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}

// auth switch request
conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97,
Expand Down
72 changes: 45 additions & 27 deletions dsn.go
Expand Up @@ -46,22 +46,23 @@ type Config struct {
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowNativePasswords bool // Allows the native password authentication method
AllowOldPasswords bool // Allows the old insecure password method
CheckConnLiveness bool // Check connections for liveness before using them
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
AllowNativePasswords bool // Allows the native password authentication method
AllowOldPasswords bool // Allows the old insecure password method
CheckConnLiveness bool // Check connections for liveness before using them
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
}

// NewConfig creates a new Config and sets default values.
Expand All @@ -77,8 +78,8 @@ func NewConfig() *Config {

func (cfg *Config) Clone() *Config {
cp := *cfg
if cp.tls != nil {
cp.tls = cfg.tls.Clone()
if cp.TLS != nil {
cp.TLS = cfg.TLS.Clone()
}
if len(cp.Params) > 0 {
cp.Params = make(map[string]string, len(cfg.Params))
Expand Down Expand Up @@ -119,24 +120,29 @@ func (cfg *Config) normalize() error {
cfg.Addr = ensureHavePort(cfg.Addr)
}

switch cfg.TLSConfig {
case "false", "":
// don't set anything
case "true":
cfg.tls = &tls.Config{}
case "skip-verify", "preferred":
cfg.tls = &tls.Config{InsecureSkipVerify: true}
default:
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
if cfg.tls == nil {
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
if cfg.TLS == nil {
switch cfg.TLSConfig {
case "false", "":
// don't set anything
case "true":
cfg.TLS = &tls.Config{}
case "skip-verify":
cfg.TLS = &tls.Config{InsecureSkipVerify: true}
case "preferred":
cfg.TLS = &tls.Config{InsecureSkipVerify: true}
cfg.AllowFallbackToPlaintext = true
default:
cfg.TLS = getTLSConfigClone(cfg.TLSConfig)
if cfg.TLS == nil {
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
}
}
}

if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
cfg.TLS.ServerName = host
}
}

Expand Down Expand Up @@ -204,6 +210,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
}

if cfg.AllowFallbackToPlaintext {
writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true")
}

if !cfg.AllowNativePasswords {
writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
}
Expand Down Expand Up @@ -391,6 +401,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value)
}

// Allow fallback to unencrypted connection if server does not support TLS
case "allowFallbackToPlaintext":
var isBool bool
cfg.AllowFallbackToPlaintext, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}

// Use native password authentication
case "allowNativePasswords":
var isBool bool
Expand Down
47 changes: 24 additions & 23 deletions dsn_test.go
Expand Up @@ -42,8 +42,8 @@ var testDSNs = []struct {
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
}, {
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false},
}, {
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestDSNParser(t *testing.T) {
}

// pointer not static
cfg.tls = nil
cfg.TLS = nil

if !reflect.DeepEqual(cfg, tst.out) {
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
Expand All @@ -100,6 +100,7 @@ func TestDSNParserInvalid(t *testing.T) {
"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
"net()/", // unknown default addr
"user:pass@tcp(127.0.0.1:3306)/db/name", // invalid dbname
"user:password@/dbname?allowFallbackToPlaintext=PREFERRED", // wrong bool flag
//"/dbname?arg=/some/unescaped/path",
}

Expand All @@ -118,7 +119,7 @@ func TestDSNReformat(t *testing.T) {
t.Error(err.Error())
continue
}
cfg1.tls = nil // pointer not static
cfg1.TLS = nil // pointer not static
res1 := fmt.Sprintf("%+v", cfg1)

dsn2 := cfg1.FormatDSN()
Expand All @@ -127,7 +128,7 @@ func TestDSNReformat(t *testing.T) {
t.Error(err.Error())
continue
}
cfg2.tls = nil // pointer not static
cfg2.TLS = nil // pointer not static
res2 := fmt.Sprintf("%+v", cfg2)

if res1 != res2 {
Expand Down Expand Up @@ -203,7 +204,7 @@ func TestDSNWithCustomTLS(t *testing.T) {

if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
} else if cfg.TLS.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
}

Expand All @@ -214,7 +215,7 @@ func TestDSNWithCustomTLS(t *testing.T) {

if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
} else if cfg.TLS.ServerName != name {
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
} else if tlsCfg.ServerName != "" {
t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
Expand All @@ -229,23 +230,23 @@ func TestDSNTLSConfig(t *testing.T) {
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
if cfg.TLS == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
if cfg.TLS.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName)
}

dsn = "tcp(example.com)/?tls=true"
cfg, err = ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
if cfg.TLS == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
if cfg.TLS.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.TLS.ServerName)
}
}

Expand All @@ -262,7 +263,7 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) {

if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
} else if cfg.TLS.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
}
}
Expand Down Expand Up @@ -335,12 +336,12 @@ func TestCloneConfig(t *testing.T) {
t.Errorf("Config.Clone did not create a separate config struct")
}

if cfg2.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
if cfg2.TLS.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName)
}

cfg2.tls.ServerName = "example2.com"
if cfg.tls.ServerName == cfg2.tls.ServerName {
cfg2.TLS.ServerName = "example2.com"
if cfg.TLS.ServerName == cfg2.TLS.ServerName {
t.Errorf("changed cfg.tls.Server name should not propagate to original Config")
}

Expand Down Expand Up @@ -384,20 +385,20 @@ func TestNormalizeTLSConfig(t *testing.T) {

cfg.normalize()

if cfg.tls == nil {
if cfg.TLS == nil {
if tc.want != nil {
t.Fatal("wanted a tls config but got nil instead")
}
return
}

if cfg.tls.ServerName != tc.want.ServerName {
if cfg.TLS.ServerName != tc.want.ServerName {
t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
tc.want.ServerName, cfg.tls.ServerName)
tc.want.ServerName, cfg.TLS.ServerName)
}
if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
if cfg.TLS.InsecureSkipVerify != tc.want.InsecureSkipVerify {
t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
tc.want.InsecureSkipVerify, cfg.TLS.InsecureSkipVerify)
}
})
}
Expand Down
12 changes: 6 additions & 6 deletions packets.go
Expand Up @@ -222,9 +222,9 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if mc.flags&clientProtocol41 == 0 {
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
if mc.cfg.TLSConfig == "preferred" {
mc.cfg.tls = nil
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
if mc.cfg.AllowFallbackToPlaintext {
mc.cfg.TLS = nil
} else {
return nil, "", ErrNoTLS
}
Expand Down Expand Up @@ -292,7 +292,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}

// To enable TLS / SSL
if mc.cfg.tls != nil {
if mc.cfg.TLS != nil {
clientFlags |= clientSSL
}

Expand Down Expand Up @@ -356,14 +356,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.tls != nil {
if mc.cfg.TLS != nil {
// Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}

// Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
if err := tlsConn.Handshake(); err != nil {
return err
}
Expand Down

0 comments on commit 41dd159

Please sign in to comment.