Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 1567 (#1570) #1571

Merged
merged 1 commit into from Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions connection.go
Expand Up @@ -152,11 +152,11 @@ func (mc *mysqlConn) cleanup() {

// Makes cleanup idempotent
close(mc.closech)
nc := mc.netConn
if nc == nil {
conn := mc.rawConn
if conn == nil {
return
}
if err := nc.Close(); err != nil {
if err := conn.Close(); err != nil {
mc.log(err)
}
// This function can be called from multiple goroutines.
Expand Down
2 changes: 1 addition & 1 deletion connector.go
Expand Up @@ -102,10 +102,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}

if err != nil {
return nil, err
}
mc.rawConn = mc.netConn

// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
Expand Down
33 changes: 33 additions & 0 deletions driver_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"log"
"math"
mrand "math/rand"
"net"
"net/url"
"os"
Expand Down Expand Up @@ -3577,3 +3578,35 @@ func runCallCommand(dbt *DBTest, query, name string) {
}
}
}

func TestIssue1567(t *testing.T) {
// enable TLS.
runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) {
// disable connection pooling.
// data race happens when new connection is created.
dbt.db.SetMaxIdleConns(0)

// estimate round trip time.
start := time.Now()
if err := dbt.db.PingContext(context.Background()); err != nil {
t.Fatal(err)
}
rtt := time.Since(start)
if rtt <= 0 {
// In some environments, rtt may become 0, so set it to at least 1ms.
rtt = time.Millisecond
}

count := 1000
if testing.Short() {
count = 10
}

for i := 0; i < count; i++ {
timeout := time.Duration(mrand.Int63n(int64(rtt)))
ctx, cancel := context.WithTimeout(context.Background(), timeout)
dbt.db.PingContext(ctx)
cancel()
}
})
}
1 change: 0 additions & 1 deletion packets.go
Expand Up @@ -351,7 +351,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
if err := tlsConn.Handshake(); err != nil {
return err
}
mc.rawConn = mc.netConn
mc.netConn = tlsConn
mc.buf.nc = tlsConn
}
Expand Down