diff --git a/AUTHORS b/AUTHORS index 146cdffdd..fb6668346 100644 --- a/AUTHORS +++ b/AUTHORS @@ -88,6 +88,7 @@ Zhenye Xie Barracuda Networks, Inc. Counting Ltd. +GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. diff --git a/conncheck.go b/conncheck.go new file mode 100644 index 000000000..fa868e84d --- /dev/null +++ b/conncheck.go @@ -0,0 +1,53 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !windows + +package mysql + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(c net.Conn) error { + var ( + n int + err error + buff [1]byte + ) + + sconn, ok := c.(syscall.Conn) + if !ok { + return nil + } + rc, err := sconn.SyscallConn() + if err != nil { + return err + } + rerr := rc.Read(func(fd uintptr) bool { + n, err = syscall.Read(int(fd), buff[:]) + return true + }) + switch { + case rerr != nil: + return rerr + case n == 0 && err == nil: + return io.EOF + case n > 0: + return errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + return nil + default: + return err + } +} diff --git a/conncheck_test.go b/conncheck_test.go new file mode 100644 index 000000000..b7234b0f5 --- /dev/null +++ b/conncheck_test.go @@ -0,0 +1,38 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10,!windows + +package mysql + +import ( + "testing" + "time" +) + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") + + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) + } + + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) + + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/conncheck_windows.go b/conncheck_windows.go new file mode 100644 index 000000000..3d9e63f66 --- /dev/null +++ b/conncheck_windows.go @@ -0,0 +1,15 @@ +package mysql + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +import "net" + +func connCheck(c net.Conn) error { + return nil +} diff --git a/connection.go b/connection.go index fc4ec7597..265fd4e47 100644 --- a/connection.go +++ b/connection.go @@ -22,6 +22,7 @@ import ( type mysqlConn struct { buf buffer netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. affectedRows uint64 insertId uint64 cfg *Config @@ -32,6 +33,7 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool + reset bool // set when the Go SQL package calls ResetSession // for context support (Go 1.8+) watching bool @@ -639,5 +641,6 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { if mc.closed.IsSet() { return driver.ErrBadConn } + mc.reset = true return nil } diff --git a/packets.go b/packets.go index 5e0853767..70d0d71f5 100644 --- a/packets.go +++ b/packets.go @@ -96,6 +96,25 @@ func (mc *mysqlConn) writePacket(data []byte) error { return ErrPktTooLarge } + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + if err := connCheck(conn); err != nil { + errLog.Print("closing bad idle connection: ", err) + mc.Close() + return driver.ErrBadConn + } + } + for { var size int if pktLen >= maxPacketSize { @@ -332,6 +351,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if err := tlsConn.Handshake(); err != nil { return err } + mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn }