diff --git a/connection.go b/connection.go index 90aec6439..39f24b15d 100644 --- a/connection.go +++ b/connection.go @@ -182,7 +182,15 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + // FIXME - seems like a bug in MySQL (or it's intended). + // There's no EOF return after parameters. + // However, this behavior isn't consistent to Maria DB. + if mc.flags&clientDeprecateEOF == 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } + if err = mc.readExactPackets(stmt.paramCount); err != nil { return nil, err } } diff --git a/packets.go b/packets.go index 6664e5ae5..b3d3b4850 100644 --- a/packets.go +++ b/packets.go @@ -235,10 +235,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 + // capability flags (upper 2 bytes) [2 bytes] + mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + pos += 2 + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += 1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -286,6 +291,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | + mc.flags&clientDeprecateEOF | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -608,20 +614,21 @@ func readStatus(b []byte) statusFlag { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html func (mc *mysqlConn) handleOkPacket(data []byte) error { - var n, m int - - // 0x00 [1 byte] - + // 0x00 or 0xFE [1 byte] + n := 1 + var l int // Affected rows [Length Coded Binary] - mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + mc.affectedRows, _, l = readLengthEncodedInteger(data[n:]) + n += l // Insert id [Length Coded Binary] - mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + mc.insertId, _, l = readLengthEncodedInteger(data[n:]) + n += l // server_status [2 bytes] - mc.status = readStatus(data[1+n+m : 1+n+m+2]) + mc.status = readStatus(data[n : n+2]) if mc.status&statusMoreResultsExists != 0 { return nil } @@ -631,19 +638,36 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { return nil } +// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet +// acting as an EOF. +func (mc *mysqlConn) isEOFPacket(data []byte) bool { + // Legacy EOF packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) && mc.flags&clientDeprecateEOF == 0 { + return true + } + return data[0] == iEOF && len(data) < 9 && mc.flags&clientDeprecateEOF != 0 +} + // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + // If we set clientDeprecateEOF capability flag, + // the EOF will be no longer sent after all columns. + packets := count + if mc.flags&clientDeprecateEOF == 0 { + // Legacy way, read one more EOF packet. + packets += 1 + } + + for i := 0; i < packets; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if mc.isEOFPacket(data) { if i == count { return columns, nil } @@ -729,9 +753,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} } + return columns, nil } -// Read Packets as Field Packets until EOF-Packet or an Error appears +// Read Packets as Field Packets until EOF/OK-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc @@ -746,9 +771,16 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if mc.isEOFPacket(data) { + if mc.flags&clientDeprecateEOF == 0 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + } else { + if err := mc.handleOkPacket(data); err != nil { + rows.mc = nil + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -808,16 +840,27 @@ func (mc *mysqlConn) readUntilEOF() error { return err } - switch data[0] { - case iERR: + switch { + case data[0] == iERR: return mc.handleErrorPacket(data) - case iEOF: - if len(data) == 5 { + case mc.isEOFPacket(data): + if mc.flags&clientDeprecateEOF == 0 { mc.status = readStatus(data[3:]) + return nil } - return nil + return mc.handleOkPacket(data) + } + } +} + +func (mc *mysqlConn) readExactPackets(num int) error { + for i := 0; i < num; i++ { + _, err := mc.readPacket() + if err != nil { + return err } } + return nil } /****************************************************************************** @@ -1178,15 +1221,22 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + if rows.mc.isEOFPacket(data) { + if rows.mc.flags&clientDeprecateEOF == 0 { + rows.mc.status = readStatus(data[3:]) + } else { + if err := rows.mc.handleOkPacket(data); err != nil { + rows.mc = nil + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } + mc := rows.mc rows.mc = nil diff --git a/rows.go b/rows.go index 888bdb5f0..1599ee03b 100644 --- a/rows.go +++ b/rows.go @@ -215,7 +215,6 @@ func (rows *textRows) Next(dest []driver.Value) error { if err := mc.error(); err != nil { return err } - // Fetch next row from stream return rows.readRow(dest) } diff --git a/statement.go b/statement.go index 18a3ae498..cc7c93f36 100644 --- a/statement.go +++ b/statement.go @@ -73,10 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.readExactPackets(resLen); err != nil { return nil, err } - // Rows if err := mc.readUntilEOF(); err != nil { return nil, err