Skip to content

Commit

Permalink
Fix issue 1567 (#1570)
Browse files Browse the repository at this point in the history
### Description

closes #1567

When TLS is enabled, `mc.netConn` is rewritten after the TLS handshak as
detailed here:


https://github.com/go-sql-driver/mysql/blob/d86c4527bae98ccd4e5060f72887520ce30eda5e/packets.go#L355

Therefore, `mc.netConn` should not be accessed within the watcher
goroutine.
Instead, `mc.rawConn` should be initialized prior to invoking
`mc.startWatcher`, and `mc.rawConn` should be used in lieu of
`mc.netConn`.

### Checklist
- [x] Code compiles correctly
- [x] Created tests which fail without the change (if possible)
- [x] All tests passing
- [x] Extended the README / documentation, if necessary
- [x] Added myself / the copyright holder to the AUTHORS file


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Refactor**
	- Improved variable naming for better code readability and maintenance.
	- Enhanced network connection handling logic.
- **New Features**
	- Updated TCP connection handling to better support TCP Keepalives.
- **Tests**
- Added a new test to address and verify the fix for a specific issue
related to TLS, connection pooling, and round trip time estimation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
shogo82148 committed Mar 22, 2024
1 parent d86c452 commit d7ddb8b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 5 deletions.
6 changes: 3 additions & 3 deletions connection.go
Expand Up @@ -153,11 +153,11 @@ func (mc *mysqlConn) cleanup() {

// Makes cleanup idempotent
close(mc.closech)
nc := mc.netConn
if nc == nil {
conn := mc.rawConn
if conn == nil {
return
}
if err := nc.Close(); err != nil {
if err := conn.Close(); err != nil {
mc.log(err)
}
// This function can be called from multiple goroutines.
Expand Down
2 changes: 1 addition & 1 deletion connector.go
Expand Up @@ -102,10 +102,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}

if err != nil {
return nil, err
}
mc.rawConn = mc.netConn

// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
Expand Down
33 changes: 33 additions & 0 deletions driver_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"log"
"math"
mrand "math/rand"
"net"
"net/url"
"os"
Expand Down Expand Up @@ -3577,3 +3578,35 @@ func runCallCommand(dbt *DBTest, query, name string) {
}
}
}

func TestIssue1567(t *testing.T) {
// enable TLS.
runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) {
// disable connection pooling.
// data race happens when new connection is created.
dbt.db.SetMaxIdleConns(0)

// estimate round trip time.
start := time.Now()
if err := dbt.db.PingContext(context.Background()); err != nil {
t.Fatal(err)
}
rtt := time.Since(start)
if rtt <= 0 {
// In some environments, rtt may become 0, so set it to at least 1ms.
rtt = time.Millisecond
}

count := 1000
if testing.Short() {
count = 10
}

for i := 0; i < count; i++ {
timeout := time.Duration(mrand.Int63n(int64(rtt)))
ctx, cancel := context.WithTimeout(context.Background(), timeout)
dbt.db.PingContext(ctx)
cancel()
}
})
}
1 change: 0 additions & 1 deletion packets.go
Expand Up @@ -351,7 +351,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
if err := tlsConn.Handshake(); err != nil {
return err
}
mc.rawConn = mc.netConn
mc.netConn = tlsConn
mc.buf.nc = tlsConn
}
Expand Down

0 comments on commit d7ddb8b

Please sign in to comment.