Skip to content

Commit

Permalink
in function MySQLDriver.Open: replace errBadConnNoWrite with driver.E…
Browse files Browse the repository at this point in the history
…rrBadConn for resend while 'bad connection' happenning
  • Loading branch information
安佳玮 committed Oct 19, 2018
1 parent 361f66e commit 59a6cd6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
mc.cleanup()
return nil, err
return nil, mc.markBadConn(err)
}

// Handle response to auth packet, switch methods if possible
Expand Down
16 changes: 16 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2860,3 +2860,19 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
})
}

func TestWriteHandshakeResponseErr(t *testing.T) {
oldWritter := connWritter
defer func() {
connWritter = oldWritter
}()
connWritter = func(conn net.Conn, data []byte) (int, error) {
return 0, fmt.Errorf("network error")
}

md := MySQLDriver{}
_, err := md.Open(dsn)
if err != driver.ErrBadConn {
t.Fatalf("error is not driver.ErrBadConn: %v", err)
}
}
8 changes: 7 additions & 1 deletion packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ import (
"fmt"
"io"
"math"
"net"
"time"
)

// connWritter write data with net.Conn, for test mocking
var connWritter = func(conn net.Conn, data []byte) (int, error) {
return conn.Write(data)
}

// Packets documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html

Expand Down Expand Up @@ -118,7 +124,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
}
}

n, err := mc.netConn.Write(data[:4+size])
n, err := connWritter(mc.netConn, data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
Expand Down

0 comments on commit 59a6cd6

Please sign in to comment.