diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c5b2aa31..ea8a972b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,7 +83,7 @@ jobs: my-cnf: | innodb_log_file_size=256MB innodb_buffer_pool_size=512MB - max_allowed_packet=16MB + max_allowed_packet=48MB ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 diff --git a/AUTHORS b/AUTHORS index 4021b96c..cbe4316d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -21,6 +21,7 @@ Animesh Ray Arne Hormann Ariel Mashraki Asta Xie +B Lamarche Brian Hendriks Bulat Gaifullin Caine Jette @@ -60,6 +61,7 @@ Jennifer Purevsuren Jerome Meyer Jiajia Zhong Jian Zhen +Joe Mann Joshua Prunier Julien Lefevre Julien Schmidt diff --git a/README.md b/README.md index 6c6abf9c..5267858a 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation + * Supports zlib compression. ## Requirements @@ -267,6 +268,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `compress` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Toggles zlib compression. false by default. + ##### `interpolateParams` ``` @@ -310,6 +321,15 @@ Default: 64*1024*1024 Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. +##### `minCompressLength` + +``` +Type: decimal number +Default: 50 +``` + +Min packet size in bytes to compress, when compression is enabled (see the `compress` parameter). Packets smaller than this will be sent uncompressed. + ##### `multiStatements` ``` diff --git a/benchmark_test.go b/benchmark_test.go index a4ecc0a6..cb2a2bea 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -46,9 +46,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open(driverNameTest, dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open(driverNameTest, dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -60,10 +64,18 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -224,6 +236,7 @@ func BenchmarkInterpolation(b *testing.B) { maxWriteSize: maxPacketSize - 1, buf: newBuffer(nil), } + mc.packetReader = &mc.buf args := []driver.Value{ int64(42424242), @@ -269,7 +282,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -305,7 +318,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -323,7 +336,7 @@ func BenchmarkExecContext(b *testing.B) { // "size=" means size of each blobs. func BenchmarkQueryRawBytes(b *testing.B) { var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS bench_rawbytes", "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", ) @@ -376,7 +389,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { // BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. func BenchmarkReceiveMassiveRows(b *testing.B) { // Setup -- prepare 10000 rows. - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") defer db.Close() diff --git a/compress.go b/compress.go new file mode 100644 index 00000000..54289a61 --- /dev/null +++ b/compress.go @@ -0,0 +1,246 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" +) + +var ( + zrPool *sync.Pool // Do not use directly. Use zDecompress() instead. + zwPool *sync.Pool // Do not use directly. Use zCompress() instead. +) + +func init() { + zrPool = &sync.Pool{ + New: func() any { return nil }, + } + zwPool = &sync.Pool{ + New: func() any { + zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) + if err != nil { + panic(err) // compress/zlib return non-nil error only if level is invalid + } + return zw + }, + } +} + +func zDecompress(src, dst []byte) (int, error) { + br := bytes.NewReader(src) + var zr io.ReadCloser + var err error + + if a := zrPool.Get(); a == nil { + if zr, err = zlib.NewReader(br); err != nil { + return 0, err + } + } else { + zr = a.(io.ReadCloser) + if zr.(zlib.Resetter).Reset(br, nil); err != nil { + return 0, err + } + } + defer func() { + zr.Close() + zrPool.Put(zr) + }() + + lenRead := 0 + size := len(dst) + + for lenRead < size { + n, err := zr.Read(dst[lenRead:]) + lenRead += n + + if err == io.EOF { + if lenRead < size { + return lenRead, io.ErrUnexpectedEOF + } + } else if err != nil { + return lenRead, err + } + } + return lenRead, nil +} + +func zCompress(src []byte, dst io.Writer) error { + zw := zwPool.Get().(*zlib.Writer) + zw.Reset(dst) + if _, err := zw.Write(src); err != nil { + return err + } + zw.Close() + zwPool.Put(zw) + return nil +} + +type decompressor struct { + mc *mysqlConn + // read buffer (FIFO). + // We can not reuse already-read buffer until dropping Go 1.20 support. + // It is because of database/mysql's weired behavior. + // See https://github.com/go-sql-driver/mysql/issues/1435 + bytesBuf []byte +} + +func newDecompressor(mc *mysqlConn) *decompressor { + return &decompressor{ + mc: mc, + } +} + +func (c *decompressor) readNext(need int) ([]byte, error) { + for len(c.bytesBuf) < need { + if err := c.uncompressPacket(); err != nil { + return nil, err + } + } + + data := c.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf + c.bytesBuf = c.bytesBuf[need:] + return data, nil +} + +func (c *decompressor) uncompressPacket() error { + header, err := c.mc.buf.readNext(7) // size of compressed header + if err != nil { + return err + } + + // compressed header structure + comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + compressionSequence := uint8(header[3]) + if debugTrace { + traceLogger.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + comprLength, uncompressedLength, compressionSequence, c.mc.sequence) + } + if compressionSequence != c.mc.sequence { + // return ErrPktSync + // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) + // before receiving all packets from client. In this case, seqnr is younger than expected. + if debugTrace { + traceLogger.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence) + } + // TODO(methane): report error when the packet is not an error packet. + } + c.mc.sequence = compressionSequence + 1 + c.mc.compressSequence = c.mc.sequence + + comprData, err := c.mc.buf.readNext(comprLength) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + c.bytesBuf = append(c.bytesBuf, comprData...) + return nil + } + + // use existing capacity in bytesBuf if possible + offset := len(c.bytesBuf) + if cap(c.bytesBuf)-offset < uncompressedLength { + old := c.bytesBuf + c.bytesBuf = make([]byte, offset, offset+uncompressedLength) + copy(c.bytesBuf, old) + } + + lenRead, err := zDecompress(comprData, c.bytesBuf[offset:offset+uncompressedLength]) + if err != nil { + return err + } + if lenRead != uncompressedLength { + return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", + uncompressedLength, lenRead) + } + c.bytesBuf = c.bytesBuf[:offset+uncompressedLength] + return nil +} + +const maxPayloadLen = maxPacketSize - 4 + +// writeCompressed sends one or some packets with compression. +// Use this instead of mc.netConn.Write() when mc.compress is true. +func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) { + totalBytes := len(packets) + dataLen := len(packets) + blankHeader := make([]byte, 7) + var buf bytes.Buffer + + for dataLen > 0 { + payloadLen := dataLen + if payloadLen > maxPayloadLen { + payloadLen = maxPayloadLen + } + payload := packets[:payloadLen] + uncompressedLen := payloadLen + + if _, err := buf.Write(blankHeader); err != nil { + return 0, err + } + + // If payload is less than minCompressLength, don't compress. + if uncompressedLen < minCompressLength { + if _, err := buf.Write(payload); err != nil { + return 0, err + } + uncompressedLen = 0 + } else { + zCompress(payload, &buf) + } + + if err := mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + return 0, err + } + dataLen -= payloadLen + packets = packets[payloadLen:] + buf.Reset() + } + + return totalBytes, nil +} + +// writeCompressedPacket writes a compressed packet with header. +// data should start with 7 size space for header followed by payload. +func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error { + comprLength := len(data) - 7 + if debugTrace { + traceLogger.Printf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, mc.compressSequence) + } + + // compression header + data[0] = byte(0xff & comprLength) + data[1] = byte(0xff & (comprLength >> 8)) + data[2] = byte(0xff & (comprLength >> 16)) + + data[3] = mc.compressSequence + + // this value is never greater than maxPayloadLength + data[4] = byte(0xff & uncompressedLen) + data[5] = byte(0xff & (uncompressedLen >> 8)) + data[6] = byte(0xff & (uncompressedLen >> 16)) + + if _, err := mc.netConn.Write(data); err != nil { + mc.log("writing compressed packet:", err) + return err + } + + mc.compressSequence++ + return nil +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 00000000..6d81db33 --- /dev/null +++ b/compress_test.go @@ -0,0 +1,124 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + conn := new(mockConn) + mc.netConn = conn + + n, err := mc.writeCompressed(uncompressedPacket) + if err != nil { + t.Fatal(err) + } + if n != len(uncompressedPacket) { + t.Fatalf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n) + } + return conn.written +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { + // mocking out buf variable + conn := new(mockConn) + conn.data = compressedPacket + mc.buf.nc = conn + cr := newDecompressor(mc) + + uncompressedPacket, err := cr.readNext(expSize) + if err != nil { + if err != io.EOF { + t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) + } + } + if len(uncompressedPacket) != expSize { + t.Errorf("uncompressed size is unexpected. expected %d but got %d", expSize, len(uncompressedPacket)) + } + return uncompressedPacket +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte{0}, + desc: "0 byte"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: "32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: make([]byte, 0), + desc: "nothing"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: bytes.Repeat(makeRandByteSlice(100), 10000), + desc: "100 rand * 10000 repeat bytes", + }, + } + + _, cSend := newRWMockConn(0) + cSend.compress = true + _, cReceive := newRWMockConn(0) + cReceive.compress = true + + for _, test := range tests { + s := fmt.Sprintf("Test roundtrip with %s", test.desc) + cSend.resetSequenceNr() + cReceive.resetSequenceNr() + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Fatalf("%s: roundtrip failed", s) + } + + if cSend.sequence != cReceive.sequence { + t.Errorf("inconsistent sequence number: send=%v recv=%v", + cSend.sequence, cReceive.sequence) + } + if cSend.compressSequence != cReceive.compressSequence { + t.Errorf("inconsistent compress sequence number: send=%v recv=%v", + cSend.compressSequence, cReceive.compressSequence) + } + } +} diff --git a/connection.go b/connection.go index 7b8abeb0..758dba0c 100644 --- a/connection.go +++ b/connection.go @@ -26,6 +26,7 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). + packetReader packetReader cfg *Config connector *connector maxAllowedPacket int @@ -34,7 +35,9 @@ type mysqlConn struct { flags clientFlag status statusFlag sequence uint8 + compressSequence uint8 parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -50,6 +53,26 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } +type packetReader interface { + readNext(need int) ([]byte, error) +} + +func (mc *mysqlConn) resetSequenceNr() { + mc.sequence = 0 + mc.compressSequence = 0 +} + +// syncSequenceNr must be called when finished writing some packet and before start reading. +func (mc *mysqlConn) syncSequenceNr() { + // Syncs compressionSequence to sequence. + // This is not documented but done in `net_flush()` in MySQL and MariaDB. + // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 + // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 + if mc.compress { + mc.sequence = mc.compressSequence + } +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -158,7 +181,7 @@ func (mc *mysqlConn) cleanup() { return } if err := conn.Close(); err != nil { - mc.log(err) + mc.log("closing connection:", err) } // This function can be called from multiple goroutines. // So we can not mc.clearResult() here. diff --git a/connection_test.go b/connection_test.go index c59cb617..4a316d46 100644 --- a/connection_test.go +++ b/connection_test.go @@ -25,6 +25,7 @@ func TestInterpolateParams(t *testing.T) { InterpolateParams: true, }, } + mc.packetReader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -72,6 +73,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { InterpolateParams: true, }, } + mc.packetReader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -90,6 +92,8 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { }, } + mc.packetReader = &mc.buf + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` if err != driver.ErrSkip { @@ -159,9 +163,12 @@ func TestCleanCancel(t *testing.T) { func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} + + buf := newBuffer(nc) mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: buf, + packetReader: &buf, maxAllowedPacket: defaultMaxAllowedPacket, } @@ -174,9 +181,12 @@ func TestPingMarkBadConnection(t *testing.T) { func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} + + buf := newBuffer(nc) mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: buf, + packetReader: &buf, maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index b6707759..c5b54524 100644 --- a/connector.go +++ b/connector.go @@ -123,6 +123,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { defer mc.finish() mc.buf = newBuffer(mc.netConn) + // packet reader and writer in handshake are never compressed + mc.packetReader = &mc.buf // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout @@ -165,6 +167,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + mc.compress = true + mc.packetReader = newDecompressor(mc) + } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/const.go b/const.go index 22526e03..58d64c61 100644 --- a/const.go +++ b/const.go @@ -11,11 +11,14 @@ package mysql import "runtime" const ( + debugTrace = false // for debugging wire protocol. + defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" + minCompressLength = 150 // Connection attributes // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available diff --git a/driver_test.go b/driver_test.go index 4fd196d4..e73fca31 100644 --- a/driver_test.go +++ b/driver_test.go @@ -147,7 +147,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db, err := sql.Open(driverNameTest, dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn, err) } defer db.Close() @@ -160,11 +160,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open(driverNameTest, dsn2) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn2, err) } defer db2.Close() } + dsn3 := dsn + "&compress=true" + var db3 *sql.DB + db3, err = sql.Open(driverNameTest, dsn3) + if err != nil { + t.Fatalf("connecting %q: %s", dsn3, err) + } + defer db3.Close() + for _, test := range tests { test := test t.Run("default", func(t *testing.T) { @@ -179,6 +187,11 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { test(dbt2) }) } + t.Run("compress", func(t *testing.T) { + dbt3 := &DBTest{t, db3} + t.Cleanup(cleanup) + test(dbt3) + }) } } @@ -1265,7 +1278,8 @@ func TestLongData(t *testing.T) { var rows *sql.Rows // Long text data - const nonDataQueryLen = 28 // length query w/o value + // const nonDataQueryLen = 28 // length query w/o value + compress header + const nonDataQueryLen = 100 inS := in[:maxAllowedPacketSize-nonDataQueryLen] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") @@ -3540,6 +3554,10 @@ func TestConnectionAttributes(t *testing.T) { func TestErrorInMultiResult(t *testing.T) { // https://github.com/go-sql-driver/mysql/issues/1361 + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { db, err = sql.Open("mysql", dsn) diff --git a/dsn.go b/dsn.go index 65f5a024..d9d9b8e5 100644 --- a/dsn.go +++ b/dsn.go @@ -70,7 +70,10 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - // unexported fields. new options should be come here + // unexported fields. new options should be come here. + // boolean first. alphabetical order. + + compress bool // Enable zlib compression beforeConnect func(context.Context, *Config) error // Invoked before a connection is established pubKey *rsa.PublicKey // Server public key @@ -90,7 +93,6 @@ func NewConfig() *Config { AllowNativePasswords: true, CheckConnLiveness: true, } - return cfg } @@ -122,6 +124,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option { } } +// EnableCompress sets the compression mode. +func EnableCompression(yes bool) Option { + return func(cfg *Config) error { + cfg.compress = yes + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { @@ -290,6 +300,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } + if cfg.compress { + writeDSNParam(&buf, &hasParam, "compress", "true") + } + if cfg.InterpolateParams { writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } @@ -514,7 +528,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.compress, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/errors.go b/errors.go index a7ef8890..f6e4ff4c 100644 --- a/errors.go +++ b/errors.go @@ -39,11 +39,24 @@ var ( var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) +// traceLogger is used for debug trace log. +var traceLogger *log.Logger + +func init() { + if debugTrace { + traceLogger = log.New(os.Stderr, "[mysql.trace]", log.Lmicroseconds|log.Lshortfile) + } +} + // Logger is used to log critical error messages. type Logger interface { Print(v ...any) } +func (mc *mysqlConn) logf(format string, v ...any) { + mc.cfg.Logger.Print(fmt.Sprintf(format, v...)) +} + // NopLogger is a nop implementation of the Logger interface. type NopLogger struct{} diff --git a/infile.go b/infile.go index 0c8af9f1..1f1b8873 100644 --- a/infile.go +++ b/infile.go @@ -171,6 +171,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } + mc.conn().syncSequenceNr() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index cf3412ff..1bbc0491 100644 --- a/packets.go +++ b/packets.go @@ -28,9 +28,11 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte + invalid := false + for { // read packet header - data, err := mc.buf.readNext(4) + data, err := mc.packetReader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -42,16 +44,28 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + seqNr := data[3] - // check packet sync [8 bit] - if data[3] != mc.sequence { - mc.Close() - if data[3] > mc.sequence { - return nil, ErrPktSyncMul + if mc.compress { + // MySQL and MariaDB doesn't check packet nr in compressed packet. + if debugTrace && seqNr != mc.compressSequence { + mc.logf("[debug] mismatched compression sequence nr: expected: %v, got %v", + mc.compressSequence, seqNr) + } + mc.compressSequence = seqNr + 1 + } else { + // check packet sync [8 bit] + if seqNr != mc.sequence { + mc.logf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr) + // For large packets, we stop reading as soon as sync error. + if len(prevData) > 0 { + return nil, ErrPktSyncMul + } + // TODO(methane): report error when the packet is not an error packet. + invalid = true } - return nil, ErrPktSync + mc.sequence++ } - mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long @@ -62,12 +76,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.Close() return nil, ErrInvalidConn } - return prevData, nil } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = mc.packetReader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -81,6 +94,10 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if pktLen < maxPacketSize { // zero allocations for non-split packets if prevData == nil { + if invalid && data[0] != iERR { + // return sync error only for regular packet. + return nil, ErrPktSync + } return data, nil } @@ -115,13 +132,24 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet + if debugTrace { + traceLogger.Printf("writePacket: size=%v seq=%v", size, mc.sequence) + } if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { return err } } - n, err := mc.netConn.Write(data[:4+size]) + var ( + n int + err error + ) + if mc.compress { + n, err = mc.writeCompressed(data[:4+size]) + } else { + n, err = mc.netConn.Write(data[:4+size]) + } if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -203,6 +231,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrNoTLS } } + pos += 2 if len(data) > pos { @@ -265,12 +294,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } - + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + clientFlags |= clientCompress + } // To enable TLS / SSL if mc.cfg.TLS != nil { clientFlags |= clientSSL } - if mc.cfg.MultiStatements { clientFlags |= clientMultiStatements } @@ -407,14 +437,11 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequenceNr() - data, err := mc.buf.takeSmallBuffer(4 + 1) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite - } + // We do not use mc.buf because this function is used by mc.Close() + // and mc.Close() could be used when some error happend during read. + data := make([]byte, 5) // Add command byte data[4] = command @@ -425,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequenceNr() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) @@ -442,12 +469,14 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequenceNr() + return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequenceNr() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { @@ -931,7 +960,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 + stmt.mc.resetSequenceNr() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -952,11 +981,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { continue } return err - } // Reset Packet Sequence - stmt.mc.sequence = 0 + stmt.mc.resetSequenceNr() return nil } @@ -981,7 +1009,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.sequence = 0 + mc.resetSequenceNr() var data []byte var err error @@ -1202,7 +1230,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequenceNr() + return err } // For each remaining resultset in the stream, discards its rows and updates diff --git a/packets_test.go b/packets_test.go index fa4683ea..1d4de7af 100644 --- a/packets_test.go +++ b/packets_test.go @@ -97,8 +97,10 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + buf: buf, + packetReader: &buf, cfg: connector.cfg, connector: connector, netConn: conn, @@ -114,6 +116,7 @@ func TestReadPacketSingleByte(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } + mc.packetReader = &mc.buf conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 @@ -143,7 +146,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { { ClientSequenceID: 0, ServerSequenceID: 0x42, - ExpectedErr: ErrPktSyncMul, + ExpectedErr: ErrPktSync, }, } { conn, mc := newRWMockConn(testCase.ClientSequenceID) @@ -166,6 +169,7 @@ func TestReadPacketSplit(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } + mc.packetReader = &mc.buf data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 @@ -273,6 +277,7 @@ func TestReadPacketFail(t *testing.T) { closech: make(chan struct{}), cfg: NewConfig(), } + mc.packetReader = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} @@ -318,6 +323,7 @@ func TestRegression801(t *testing.T) { sequence: 42, closech: make(chan struct{}), } + mc.packetReader = &mc.buf conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,