diff --git a/CONTRIBUTING.md b/.github/CONTRIBUTING.md similarity index 100% rename from CONTRIBUTING.md rename to .github/CONTRIBUTING.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..fce4cf670 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,102 @@ +name: test +on: + pull_request: + push: + workflow_dispatch: + +env: + MYSQL_TEST_USER: gotest + MYSQL_TEST_PASS: secret + MYSQL_TEST_ADDR: 127.0.0.1:3306 + MYSQL_TEST_CONCURRENT: 1 + +jobs: + list: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: list + id: set-matrix + run: | + import json + go = [ + # Keep the most recent production release at the top + '1.16', + # Older production releases + '1.15', + '1.14', + '1.13', + ] + mysql = [ + '8.0', + '5.7', + '5.6', + 'mariadb-10.5', + 'mariadb-10.4', + 'mariadb-10.3', + ] + + includes = [] + # Go versions compatibility check + for v in go[1:]: + includes.append({'os': 'ubuntu-latest', 'go': v, 'mysql': mysql[0]}) + + matrix = { + # OS vs MySQL versions + 'os': [ 'ubuntu-latest', 'macos-latest', 'windows-latest' ], + 'go': [ go[0] ], + 'mysql': mysql, + + 'include': includes + } + output = json.dumps(matrix, separators=(',', ':')) + print('::set-output name=matrix::{0}'.format(output)) + shell: python + test: + needs: list + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.list.outputs.matrix) }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - uses: shogo82148/actions-setup-mysql@v1 + with: + mysql-version: ${{ matrix.mysql }} + user: ${{ env.MYSQL_TEST_USER }} + password: ${{ env.MYSQL_TEST_PASS }} + my-cnf: | + innodb_log_file_size=256MB + innodb_buffer_pool_size=512MB + max_allowed_packet=16MB + ; TestConcurrent fails if max_connections is too large + max_connections=50 + local_infile=1 + - name: setup database + run: | + mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;' + + - name: test + run: | + go test -v '-covermode=count' '-coverprofile=coverage.out' + + - name: Send coverage + uses: shogo82148/actions-goveralls@v1 + with: + path-to-profile: coverage.out + flag-name: ${{ runner.os }}-Go-${{ matrix.go }}-DB-${{ matrix.mysql }} + parallel: true + + # notifies that all test jobs are finished. + finish: + needs: test + if: always() + runs-on: ubuntu-latest + steps: + - uses: shogo82148/actions-goveralls@v1 + with: + parallel-finished: true diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 7a4113dbc..000000000 --- a/.travis.yml +++ /dev/null @@ -1,107 +0,0 @@ -sudo: false -language: go -go: - - 1.8.x - - 1.9.x - - 1.10.x - - master - -before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - -before_script: - - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf - - sudo service mysql restart - - .travis/wait_mysql.sh - - mysql -e 'create database gotest;' - - mysql -e 'show databases;' - -matrix: - include: - - env: DB=MYSQL8 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mysql:8.0 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MYSQL57 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mysql:5.7 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MARIA55 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mariadb:5.5 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MARIA10_1 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mariadb:10.1 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - -script: - - go test -v -covermode=count -coverprofile=coverage.out - - go vet ./... - - .travis/gofmt.sh -after_script: - - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/.travis/docker.cnf b/.travis/docker.cnf deleted file mode 100644 index e57754e5a..000000000 --- a/.travis/docker.cnf +++ /dev/null @@ -1,5 +0,0 @@ -[client] -user = gotest -password = secret -host = 127.0.0.1 -port = 3307 diff --git a/.travis/gofmt.sh b/.travis/gofmt.sh deleted file mode 100755 index 9bf0d1684..000000000 --- a/.travis/gofmt.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -set -ev - -# Only check for go1.10+ since the gofmt style changed -if [[ $(go version) =~ go1\.([0-9]+) ]] && ((${BASH_REMATCH[1]} >= 10)); then - test -z "$(gofmt -d -s . | tee /dev/stderr)" -fi diff --git a/.travis/wait_mysql.sh b/.travis/wait_mysql.sh deleted file mode 100755 index e87993e57..000000000 --- a/.travis/wait_mysql.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/sh -while : -do - if mysql -e 'select version()' 2>&1 | grep 'version()\|ERROR 2059 (HY000):'; then - break - fi - sleep 3 -done diff --git a/AUTHORS b/AUTHORS index 7124f8bd4..df847251d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,12 +13,17 @@ Aaron Hopkins Achille Roussel +Alex Snast Alexey Palazhchenko Andrew Reid +Animesh Ray Arne Hormann +Ariel Mashraki Asta Xie Bulat Gaifullin +Caine Jette Carlos Nieto +Chris Kirkland Chris Moos Craig Wilson Daniel Montoya @@ -27,6 +32,7 @@ Daniël van Eeden Dave Protasowski DisposaBoy Egor Smolyakov +Erwan Martin Evan Shaw Frederick Mayle Gustavo Kristic @@ -34,12 +40,17 @@ Hajime Nakagami Hanno Braun Henri Yandell Hirotaka Yamamoto +Huyiguang ICHINOSE Shogo +Ilia Cimpoes INADA Naoki Jacek Szwec James Harr +Janek Vedock Jeff Hodges Jeffrey Charles +Jerome Meyer +Jiajia Zhong Jian Zhen Joshua Prunier Julien Lefevre @@ -47,6 +58,7 @@ Julien Schmidt Justin Li Justin Nuß Kamil Dziedzic +Kei Kamikawa Kevin Malachowski Kieron Woodhouse Lennart Rudolph @@ -55,9 +67,11 @@ Linh Tran Tuan Lion Yang Luca Looz Lucas Liu +Lunny Xiao Luke Scott Maciej Zimnoch Michael Woolnough +Nathanial Murphy Nicola Peduzzi Olivier Mengué oscarzhao @@ -68,23 +82,41 @@ Reed Allman Richard Wilkes Robert Russell Runrioter Wung +Sho Iizuka +Sho Ikeda Shuode Li +Simon J Mudd Soroush Pour Stan Putrya Stanley Gunawan +Steven Hartland +Tan Jinhua <312841925 at qq.com> +Thomas Wodarek +Tim Ruffles +Tom Jenkinson Vasily Fedoseyev +Vladimir Kovpak +Vladyslav Zhelezniak Xiangyu Hu Xiaobing Jiang Xiuming Chen +Xuehong Chan Zhenye Xie +Zhixin Wen +Ziheng Lyu # Organizations Barracuda Networks, Inc. Counting Ltd. +DigitalOcean Inc. +Facebook Inc. +GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. +Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d87d74c9..72a738ed5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,68 @@ +## Version 1.6 (2021-04-01) + +Changes: + + - Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190) + - `NullTime` is deprecated (#960, #1144) + - Reduce allocations when building SET command (#1111) + - Performance improvement for time formatting (#1118) + - Performance improvement for time parsing (#1098, #1113) + +New Features: + + - Implement `driver.Validator` interface (#1106, #1174) + - Support returning `uint64` from `Valuer` in `ConvertValue` (#1143) + - Add `json.RawMessage` for converter and prepared statement (#1059) + - Interpolate `json.RawMessage` as `string` (#1058) + - Implements `CheckNamedValue` (#1090) + +Bugfixes: + + - Stop rounding times (#1121, #1172) + - Put zero filler into the SSL handshake packet (#1066) + - Fix checking cancelled connections back into the connection pool (#1095) + - Fix remove last 0 byte for mysql_old_password when password is empty (#1133) + + +## Version 1.5 (2020-01-07) + +Changes: + + - Dropped support Go 1.9 and lower (#823, #829, #886, #1016, #1017) + - Improve buffer handling (#890) + - Document potentially insecure TLS configs (#901) + - Use a double-buffering scheme to prevent data races (#943) + - Pass uint64 values without converting them to string (#838, #955) + - Update collations and make utf8mb4 default (#877, #1054) + - Make NullTime compatible with sql.NullTime in Go 1.13+ (#995) + - Removed CloudSQL support (#993, #1007) + - Add Go Module support (#1003) + +New Features: + + - Implement support of optional TLS (#900) + - Check connection liveness (#934, #964, #997, #1048, #1051, #1052) + - Implement Connector Interface (#941, #958, #1020, #1035) + +Bugfixes: + + - Mark connections as bad on error during ping (#875) + - Mark connections as bad on error during dial (#867) + - Fix connection leak caused by rapid context cancellation (#1024) + - Mark connections as bad on error during Conn.Prepare (#1030) + + +## Version 1.4.1 (2018-11-14) + +Bugfixes: + + - Fix TIME format for binary columns (#818) + - Fix handling of empty auth plugin names (#835) + - Fix caching_sha2_password with empty password (#826) + - Fix canceled context broke mysqlConn (#862) + - Fix OldAuthSwitchRequest support (#870) + - Fix Auth Response packet for cleartext password (#887) + ## Version 1.4 (2018-06-03) Changes: diff --git a/README.md b/README.md index 9c712f4fe..c32f88d99 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,12 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Supports queries larger than 16MB * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. * Intelligent `LONG DATA` handling in prepared statements - * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support + * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation ## Requirements - * Go 1.7 or higher. We aim to support the 3 latest versions of Go. + * Go 1.13 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- @@ -56,15 +56,37 @@ Make sure [Git is installed](https://git-scm.com/downloads) on your machine and _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: + ```go -import "database/sql" -import _ "github.com/go-sql-driver/mysql" +import ( + "database/sql" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +// ... db, err := sql.Open("mysql", "user:password@/dbname") +if err != nil { + panic(err) +} +// See "Important settings" section. +db.SetConnMaxLifetime(time.Minute * 3) +db.SetMaxOpenConns(10) +db.SetMaxIdleConns(10) ``` [Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples"). +### Important settings + +`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too. + +`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server. + +`db.SetMaxIdleConns()` is recommended to be set same to `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed much more frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15. + ### DSN (Data Source Name) @@ -122,7 +144,7 @@ Valid Values: true, false Default: false ``` -`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. +`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) ##### `allowCleartextPasswords` @@ -133,7 +155,7 @@ Valid Values: true, false Default: false ``` -`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. +`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. ##### `allowNativePasswords` @@ -166,18 +188,34 @@ Sets the charset used for client-server interaction (`"SET NAMES "`). If Usage of the `charset` parameter is discouraged because it issues additional queries to the server. Unless you need the fallback behavior, please use `collation` instead. +##### `checkConnLiveness` + +``` +Type: bool +Valid Values: true, false +Default: true +``` + +On supported platforms connections retrieved from the connection pool are checked for liveness before using them. If the check fails, the respective connection is marked as bad and the query retried with another connection. +`checkConnLiveness=false` disables this liveness check of connections. + ##### `collation` ``` Type: string Valid Values: -Default: utf8_general_ci +Default: utf8mb4_general_ci ``` Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail. A list of valid charsets for a server is retrievable with `SHOW COLLATION`. +The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. + +Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). + + ##### `clientFoundRows` ``` @@ -224,7 +262,7 @@ Default: false If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. -*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* +*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* ##### `loc` @@ -338,11 +376,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci ``` Type: bool / string -Valid Values: true, false, skip-verify, +Valid Values: true, false, skip-verify, preferred, Default: false ``` -`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). ##### `writeTimeout` @@ -370,7 +408,7 @@ Rules: Examples: * `autocommit=1`: `SET autocommit=1` * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` - * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` + * [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` #### Examples @@ -401,14 +439,9 @@ TCP on a remote host, e.g. Amazon RDS: id:password@tcp(your-amazonaws-uri.com:3306)/dbname ``` -Google Cloud SQL on App Engine (First Generation MySQL Server): +Google Cloud SQL on App Engine: ``` -user@cloudsql(project-id:instance-name)/dbname -``` - -Google Cloud SQL on App Engine (Second Generation MySQL Server): -``` -user@cloudsql(project-id:regionname:instance-name)/dbname +user:password@unix(/cloudsql/project-id:region-name:instance-name)/dbname ``` TCP using default port (3306) on localhost: @@ -431,7 +464,7 @@ user:password@/ The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. ## `ColumnType` Support -This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`. ## `context.Context` Support Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. @@ -444,7 +477,7 @@ For this feature you need direct access to the package. Therefore you must chang import "github.com/go-sql-driver/mysql" ``` -Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. @@ -454,21 +487,19 @@ See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/my ### `time.Time` support The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. -However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. +However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). -Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. - ### Unicode support -Since version 1.1 Go-MySQL-Driver automatically uses the collation `utf8_general_ci` by default. +Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. -See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. +See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. @@ -476,7 +507,7 @@ To run the driver tests you may need to adjust the configuration. See the [Testi Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls). -See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details. +See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details. --------------------------------------- diff --git a/auth.go b/auth.go index 0b59f52ee..b2f19e8f0 100644 --- a/auth.go +++ b/auth.go @@ -15,6 +15,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "fmt" "sync" ) @@ -136,10 +137,6 @@ func pwHash(password []byte) (result [2]uint32) { // Hash password using insecure pre 4.1 method func scrambleOldPassword(scramble []byte, password string) []byte { - if len(password) == 0 { - return nil - } - scramble = scramble[:8] hashPw := pwHash([]byte(password)) @@ -234,64 +231,67 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro if err != nil { return err } - return mc.writeAuthSwitchPacket(enc, false) + return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { switch plugin { case "caching_sha2_password": authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) - return authResp, (authResp == nil), nil + return authResp, nil case "mysql_old_password": if !mc.cfg.AllowOldPasswords { - return nil, false, ErrOldPassword + return nil, ErrOldPassword + } + if len(mc.cfg.Passwd) == 0 { + return nil, nil } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) - return authResp, true, nil + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil case "mysql_clear_password": if !mc.cfg.AllowCleartextPasswords { - return nil, false, ErrCleartextPassword + return nil, ErrCleartextPassword } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { - return nil, false, ErrNativePassword + return nil, ErrNativePassword } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "sha256_password": if len(mc.cfg.Passwd) == 0 { - return nil, true, nil + return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil } pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - return []byte{1}, false, nil + return []byte{1}, nil } // encrypted password enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) - return enc, false, err + return enc, err default: errLog.Print("unknown auth plugin:", plugin) - return nil, false, ErrUnknownPlugin + return nil, ErrUnknownPlugin } } @@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { return err } - if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + if err = mc.writeAuthSwitchPacket(authResp); err != nil { return err } @@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { return err } @@ -360,17 +360,22 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - data := mc.buf.takeSmallBuffer(4 + 1) + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } data[4] = cachingSha2PasswordRequestPublicKey mc.writePacket(data) // parse public key - data, err := mc.readPacket() - if err != nil { + if data, err = mc.readPacket(); err != nil { return err } - block, _ := pem.Decode(data[1:]) + block, rest := pem.Decode(data[1:]) + if block == nil { + return fmt.Errorf("No Pem data found, data: %s", rest) + } pkix, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err diff --git a/auth_test.go b/auth_test.go index 040a016f7..3bce7fe22 100644 --- a/auth_test.go +++ b/auth_test.go @@ -13,7 +13,6 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" - "encoding/hex" "encoding/pem" "fmt" "testing" @@ -86,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -131,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -173,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -229,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -281,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -337,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -354,26 +353,23 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte("secret") - if !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response:\n%s\nExpected:\n%s\n", - hex.Dump(writtenAuthResp), hex.Dump(expectedAuthResp)) - } - if conn.written[authRespEnd] != 0 { - t.Fatalf("Expected null-terminated") + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -400,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -414,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { - t.Fatalf("unexpected written auth response (%d bytes): %v", - writtenAuthRespLen, writtenAuthResp) + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -443,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -459,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -502,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -544,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -558,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -591,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -640,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -673,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } @@ -681,24 +678,20 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.tls = nil - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte("secret") - if !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response:\n%s\nExpected:\n%s\n", - hex.Dump(writtenAuthResp), hex.Dump(expectedAuthResp)) - } - if conn.written[authRespEnd] != 0 { - t.Fatalf("Expected null-terminated") + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } - conn.written = nil // auth response (OK) @@ -772,7 +765,7 @@ func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { t.Errorf("got error: %v", err) } - expectedReply := []byte{1, 0, 0, 3, 0} + expectedReply := []byte{0, 0, 0, 3} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) } @@ -1072,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { } } +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + func TestAuthSwitchOldPassword(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.AllowOldPasswords = true @@ -1100,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) { } } +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} func TestAuthSwitchOldPasswordEmpty(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.AllowOldPasswords = true @@ -1122,7 +1157,34 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { t.Errorf("got error: %v", err) } - expectedReply := []byte{1, 0, 0, 3, 0} + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) } diff --git a/benchmark_test.go b/benchmark_test.go index 5828d40f9..1030ddc52 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -127,7 +127,8 @@ func BenchmarkExec(b *testing.B) { } if _, err := stmt.Exec(); err != nil { - b.Fatal(err.Error()) + b.Logf("stmt.Exec failed: %v", err) + b.Fail() } } }() @@ -317,3 +318,57 @@ func BenchmarkExecContext(b *testing.B) { }) } } + +// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes. +// "size=" means size of each blobs. +func BenchmarkQueryRawBytes(b *testing.B) { + var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} + db := initDB(b, + "DROP TABLE IF EXISTS bench_rawbytes", + "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", + ) + defer db.Close() + + blob := make([]byte, sizes[len(sizes)-1]) + for i := range blob { + blob[i] = 42 + } + for i := 0; i < 100; i++ { + _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) + if err != nil { + b.Fatal(err) + } + } + + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + b.ReportAllocs() + b.ResetTimer() + + for j := 0; j < b.N; j++ { + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + if err != nil { + b.Fatal(err) + } + nrows := 0 + for rows.Next() { + var buf sql.RawBytes + err := rows.Scan(&buf) + if err != nil { + b.Fatal(err) + } + if len(buf) != s { + b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + } + nrows++ + } + rows.Close() + if nrows != 100 { + b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) + } + } + }) + } +} diff --git a/buffer.go b/buffer.go index eb4748bf4..0774c5c8c 100644 --- a/buffer.go +++ b/buffer.go @@ -15,47 +15,69 @@ import ( ) const defaultBufSize = 4096 +const maxCachedBufSize = 256 * 1024 // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. +// This buffer is backed by two byte slices in a double-buffering scheme type buffer struct { - buf []byte + buf []byte // buf is a byte buffer who's length and capacity are equal. nc net.Conn idx int length int timeout time.Duration + dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer + flipcnt uint // flipccnt is the current buffer counter for double-buffering } +// newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { - var b [defaultBufSize]byte + fg := make([]byte, defaultBufSize) return buffer{ - buf: b[:], - nc: nc, + buf: fg, + nc: nc, + dbuf: [2][]byte{fg, nil}, } } +// flip replaces the active buffer with the background buffer +// this is a delayed flip that simply increases the buffer counter; +// the actual flip will be performed the next time we call `buffer.fill` +func (b *buffer) flip() { + b.flipcnt += 1 +} + // fill reads into the buffer until at least _need_ bytes are in it func (b *buffer) fill(need int) error { n := b.length + // fill data into its double-buffering target: if we've called + // flip on this buffer, we'll be copying to the background buffer, + // and then filling it with network data; otherwise we'll just move + // the contents of the current buffer to the front before filling it + dest := b.dbuf[b.flipcnt&1] + + // grow buffer if necessary to fit the whole packet. + if need > len(dest) { + // Round up to the next multiple of the default size + dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - // move existing data to the beginning - if n > 0 && b.idx > 0 { - copy(b.buf[0:n], b.buf[b.idx:]) + // if the allocated buffer is not too large, move it to backing storage + // to prevent extra allocations on applications that perform large reads + if len(dest) <= maxCachedBufSize { + b.dbuf[b.flipcnt&1] = dest + } } - // grow buffer if necessary - // TODO: let the buffer shrink again at some point - // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { - // Round up to the next multiple of the default size - newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - copy(newBuf, b.buf) - b.buf = newBuf + // if we're filling the fg buffer, move the existing data to the start of it. + // if we're filling the bg buffer, copy over the data + if n > 0 { + copy(dest[:n], b.buf[b.idx:]) } + b.buf = dest b.idx = 0 for { @@ -105,43 +127,56 @@ func (b *buffer) readNext(need int) ([]byte, error) { return b.buf[offset:b.idx], nil } -// returns a buffer with the requested size. +// takeBuffer returns a buffer with the requested size. // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + if length <= cap(b.buf) { + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + + // buffer is larger than we want to store. + return make([]byte, length), nil } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf[:length] + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { +func (b *buffer) takeCompleteBuffer() ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { if b.length > 0 { - return nil + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] } - return b.buf + return nil } diff --git a/collations.go b/collations.go index 136c9e4d1..326a9f7fa 100644 --- a/collations.go +++ b/collations.go @@ -8,183 +8,190 @@ package mysql -const defaultCollation = "utf8_general_ci" +const defaultCollation = "utf8mb4_general_ci" const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: -// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. +// +// ucs2, utf16, and utf32 can't be used for connection charset. +// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset +// They are commented out to reduce this map. var collations = map[string]byte{ - "big5_chinese_ci": 1, - "latin2_czech_cs": 2, - "dec8_swedish_ci": 3, - "cp850_general_ci": 4, - "latin1_german1_ci": 5, - "hp8_english_ci": 6, - "koi8r_general_ci": 7, - "latin1_swedish_ci": 8, - "latin2_general_ci": 9, - "swe7_swedish_ci": 10, - "ascii_general_ci": 11, - "ujis_japanese_ci": 12, - "sjis_japanese_ci": 13, - "cp1251_bulgarian_ci": 14, - "latin1_danish_ci": 15, - "hebrew_general_ci": 16, - "tis620_thai_ci": 18, - "euckr_korean_ci": 19, - "latin7_estonian_cs": 20, - "latin2_hungarian_ci": 21, - "koi8u_general_ci": 22, - "cp1251_ukrainian_ci": 23, - "gb2312_chinese_ci": 24, - "greek_general_ci": 25, - "cp1250_general_ci": 26, - "latin2_croatian_ci": 27, - "gbk_chinese_ci": 28, - "cp1257_lithuanian_ci": 29, - "latin5_turkish_ci": 30, - "latin1_german2_ci": 31, - "armscii8_general_ci": 32, - "utf8_general_ci": 33, - "cp1250_czech_cs": 34, - "ucs2_general_ci": 35, - "cp866_general_ci": 36, - "keybcs2_general_ci": 37, - "macce_general_ci": 38, - "macroman_general_ci": 39, - "cp852_general_ci": 40, - "latin7_general_ci": 41, - "latin7_general_cs": 42, - "macce_bin": 43, - "cp1250_croatian_ci": 44, - "utf8mb4_general_ci": 45, - "utf8mb4_bin": 46, - "latin1_bin": 47, - "latin1_general_ci": 48, - "latin1_general_cs": 49, - "cp1251_bin": 50, - "cp1251_general_ci": 51, - "cp1251_general_cs": 52, - "macroman_bin": 53, - "utf16_general_ci": 54, - "utf16_bin": 55, - "utf16le_general_ci": 56, - "cp1256_general_ci": 57, - "cp1257_bin": 58, - "cp1257_general_ci": 59, - "utf32_general_ci": 60, - "utf32_bin": 61, - "utf16le_bin": 62, - "binary": 63, - "armscii8_bin": 64, - "ascii_bin": 65, - "cp1250_bin": 66, - "cp1256_bin": 67, - "cp866_bin": 68, - "dec8_bin": 69, - "greek_bin": 70, - "hebrew_bin": 71, - "hp8_bin": 72, - "keybcs2_bin": 73, - "koi8r_bin": 74, - "koi8u_bin": 75, - "latin2_bin": 77, - "latin5_bin": 78, - "latin7_bin": 79, - "cp850_bin": 80, - "cp852_bin": 81, - "swe7_bin": 82, - "utf8_bin": 83, - "big5_bin": 84, - "euckr_bin": 85, - "gb2312_bin": 86, - "gbk_bin": 87, - "sjis_bin": 88, - "tis620_bin": 89, - "ucs2_bin": 90, - "ujis_bin": 91, - "geostd8_general_ci": 92, - "geostd8_bin": 93, - "latin1_spanish_ci": 94, - "cp932_japanese_ci": 95, - "cp932_bin": 96, - "eucjpms_japanese_ci": 97, - "eucjpms_bin": 98, - "cp1250_polish_ci": 99, - "utf16_unicode_ci": 101, - "utf16_icelandic_ci": 102, - "utf16_latvian_ci": 103, - "utf16_romanian_ci": 104, - "utf16_slovenian_ci": 105, - "utf16_polish_ci": 106, - "utf16_estonian_ci": 107, - "utf16_spanish_ci": 108, - "utf16_swedish_ci": 109, - "utf16_turkish_ci": 110, - "utf16_czech_ci": 111, - "utf16_danish_ci": 112, - "utf16_lithuanian_ci": 113, - "utf16_slovak_ci": 114, - "utf16_spanish2_ci": 115, - "utf16_roman_ci": 116, - "utf16_persian_ci": 117, - "utf16_esperanto_ci": 118, - "utf16_hungarian_ci": 119, - "utf16_sinhala_ci": 120, - "utf16_german2_ci": 121, - "utf16_croatian_ci": 122, - "utf16_unicode_520_ci": 123, - "utf16_vietnamese_ci": 124, - "ucs2_unicode_ci": 128, - "ucs2_icelandic_ci": 129, - "ucs2_latvian_ci": 130, - "ucs2_romanian_ci": 131, - "ucs2_slovenian_ci": 132, - "ucs2_polish_ci": 133, - "ucs2_estonian_ci": 134, - "ucs2_spanish_ci": 135, - "ucs2_swedish_ci": 136, - "ucs2_turkish_ci": 137, - "ucs2_czech_ci": 138, - "ucs2_danish_ci": 139, - "ucs2_lithuanian_ci": 140, - "ucs2_slovak_ci": 141, - "ucs2_spanish2_ci": 142, - "ucs2_roman_ci": 143, - "ucs2_persian_ci": 144, - "ucs2_esperanto_ci": 145, - "ucs2_hungarian_ci": 146, - "ucs2_sinhala_ci": 147, - "ucs2_german2_ci": 148, - "ucs2_croatian_ci": 149, - "ucs2_unicode_520_ci": 150, - "ucs2_vietnamese_ci": 151, - "ucs2_general_mysql500_ci": 159, - "utf32_unicode_ci": 160, - "utf32_icelandic_ci": 161, - "utf32_latvian_ci": 162, - "utf32_romanian_ci": 163, - "utf32_slovenian_ci": 164, - "utf32_polish_ci": 165, - "utf32_estonian_ci": 166, - "utf32_spanish_ci": 167, - "utf32_swedish_ci": 168, - "utf32_turkish_ci": 169, - "utf32_czech_ci": 170, - "utf32_danish_ci": 171, - "utf32_lithuanian_ci": 172, - "utf32_slovak_ci": 173, - "utf32_spanish2_ci": 174, - "utf32_roman_ci": 175, - "utf32_persian_ci": 176, - "utf32_esperanto_ci": 177, - "utf32_hungarian_ci": 178, - "utf32_sinhala_ci": 179, - "utf32_german2_ci": 180, - "utf32_croatian_ci": 181, - "utf32_unicode_520_ci": 182, - "utf32_vietnamese_ci": 183, + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + //"ucs2_general_ci": 35, + "cp866_general_ci": 36, + "keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + //"utf16_general_ci": 54, + //"utf16_bin": 55, + //"utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + //"utf32_general_ci": 60, + //"utf32_bin": 61, + //"utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "utf8_tolower_ci": 76, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + //"ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + //"utf16_unicode_ci": 101, + //"utf16_icelandic_ci": 102, + //"utf16_latvian_ci": 103, + //"utf16_romanian_ci": 104, + //"utf16_slovenian_ci": 105, + //"utf16_polish_ci": 106, + //"utf16_estonian_ci": 107, + //"utf16_spanish_ci": 108, + //"utf16_swedish_ci": 109, + //"utf16_turkish_ci": 110, + //"utf16_czech_ci": 111, + //"utf16_danish_ci": 112, + //"utf16_lithuanian_ci": 113, + //"utf16_slovak_ci": 114, + //"utf16_spanish2_ci": 115, + //"utf16_roman_ci": 116, + //"utf16_persian_ci": 117, + //"utf16_esperanto_ci": 118, + //"utf16_hungarian_ci": 119, + //"utf16_sinhala_ci": 120, + //"utf16_german2_ci": 121, + //"utf16_croatian_ci": 122, + //"utf16_unicode_520_ci": 123, + //"utf16_vietnamese_ci": 124, + //"ucs2_unicode_ci": 128, + //"ucs2_icelandic_ci": 129, + //"ucs2_latvian_ci": 130, + //"ucs2_romanian_ci": 131, + //"ucs2_slovenian_ci": 132, + //"ucs2_polish_ci": 133, + //"ucs2_estonian_ci": 134, + //"ucs2_spanish_ci": 135, + //"ucs2_swedish_ci": 136, + //"ucs2_turkish_ci": 137, + //"ucs2_czech_ci": 138, + //"ucs2_danish_ci": 139, + //"ucs2_lithuanian_ci": 140, + //"ucs2_slovak_ci": 141, + //"ucs2_spanish2_ci": 142, + //"ucs2_roman_ci": 143, + //"ucs2_persian_ci": 144, + //"ucs2_esperanto_ci": 145, + //"ucs2_hungarian_ci": 146, + //"ucs2_sinhala_ci": 147, + //"ucs2_german2_ci": 148, + //"ucs2_croatian_ci": 149, + //"ucs2_unicode_520_ci": 150, + //"ucs2_vietnamese_ci": 151, + //"ucs2_general_mysql500_ci": 159, + //"utf32_unicode_ci": 160, + //"utf32_icelandic_ci": 161, + //"utf32_latvian_ci": 162, + //"utf32_romanian_ci": 163, + //"utf32_slovenian_ci": 164, + //"utf32_polish_ci": 165, + //"utf32_estonian_ci": 166, + //"utf32_spanish_ci": 167, + //"utf32_swedish_ci": 168, + //"utf32_turkish_ci": 169, + //"utf32_czech_ci": 170, + //"utf32_danish_ci": 171, + //"utf32_lithuanian_ci": 172, + //"utf32_slovak_ci": 173, + //"utf32_spanish2_ci": 174, + //"utf32_roman_ci": 175, + //"utf32_persian_ci": 176, + //"utf32_esperanto_ci": 177, + //"utf32_hungarian_ci": 178, + //"utf32_sinhala_ci": 179, + //"utf32_german2_ci": 180, + //"utf32_croatian_ci": 181, + //"utf32_unicode_520_ci": 182, + //"utf32_vietnamese_ci": 183, "utf8_unicode_ci": 192, "utf8_icelandic_ci": 193, "utf8_latvian_ci": 194, @@ -234,18 +241,25 @@ var collations = map[string]byte{ "utf8mb4_croatian_ci": 245, "utf8mb4_unicode_520_ci": 246, "utf8mb4_vietnamese_ci": 247, + "gb18030_chinese_ci": 248, + "gb18030_bin": 249, + "gb18030_unicode_520_ci": 250, + "utf8mb4_0900_ai_ci": 255, } -// A blacklist of collations which is unsafe to interpolate parameters. +// A denylist of collations which is unsafe to interpolate parameters. // These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. var unsafeCollations = map[string]bool{ - "big5_chinese_ci": true, - "sjis_japanese_ci": true, - "gbk_chinese_ci": true, - "big5_bin": true, - "gb2312_bin": true, - "gbk_bin": true, - "sjis_bin": true, - "cp932_japanese_ci": true, - "cp932_bin": true, + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, + "gb18030_chinese_ci": true, + "gb18030_bin": true, + "gb18030_unicode_520_ci": true, } diff --git a/conncheck.go b/conncheck.go new file mode 100644 index 000000000..024eb2858 --- /dev/null +++ b/conncheck.go @@ -0,0 +1,54 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package mysql + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(conn net.Conn) error { + var sysErr error + + sysConn, ok := conn.(syscall.Conn) + if !ok { + return nil + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return err + } + + err = rawConn.Read(func(fd uintptr) bool { + var buf [1]byte + n, err := syscall.Read(int(fd), buf[:]) + switch { + case n == 0 && err == nil: + sysErr = io.EOF + case n > 0: + sysErr = errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + sysErr = nil + default: + sysErr = err + } + return true + }) + if err != nil { + return err + } + + return sysErr +} diff --git a/appengine.go b/conncheck_dummy.go similarity index 59% rename from appengine.go rename to conncheck_dummy.go index be41f2ee6..ea7fb607a 100644 --- a/appengine.go +++ b/conncheck_dummy.go @@ -1,19 +1,17 @@ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build appengine +// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos package mysql -import ( - "google.golang.org/appengine/cloudsql" -) +import "net" -func init() { - RegisterDial("cloudsql", cloudsql.Dial) +func connCheck(conn net.Conn) error { + return nil } diff --git a/conncheck_test.go b/conncheck_test.go new file mode 100644 index 000000000..53995517b --- /dev/null +++ b/conncheck_test.go @@ -0,0 +1,38 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package mysql + +import ( + "testing" + "time" +) + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") + + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) + } + + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) + + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/connection.go b/connection.go index 911be2060..835f89729 100644 --- a/connection.go +++ b/connection.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "io" "net" "strconv" @@ -19,19 +20,10 @@ import ( "time" ) -// a copy of context.Context for Go 1.7 and earlier -type mysqlContext interface { - Done() <-chan struct{} - Err() error - - // defined in context.Context, but not used in this driver: - // Deadline() (deadline time.Time, ok bool) - // Value(key interface{}) interface{} -} - type mysqlConn struct { buf buffer netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. affectedRows uint64 insertId uint64 cfg *Config @@ -42,10 +34,11 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool + reset bool // set when the Go SQL package calls ResetSession // for context support (Go 1.8+) watching bool - watcher chan<- mysqlContext + watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled @@ -54,9 +47,10 @@ type mysqlConn struct { // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { + var cmdSet strings.Builder for param, val := range mc.cfg.Params { switch param { - // Charset + // Charset: character_set_connection, character_set_client, character_set_results case "charset": charsets := strings.Split(val, ",") for i := range charsets { @@ -70,12 +64,25 @@ func (mc *mysqlConn) handleParams() (err error) { return } - // System Vars + // Other system vars accumulated in a single SET command default: - err = mc.exec("SET " + param + "=" + val + "") - if err != nil { - return + if cmdSet.Len() == 0 { + // Heuristic: 29 chars for each other key=value to reduce reallocations + cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.WriteString("SET ") + } else { + cmdSet.WriteByte(',') } + cmdSet.WriteString(param) + cmdSet.WriteByte('=') + cmdSet.WriteString(val) + } + } + + if cmdSet.Len() > 0 { + err = mc.exec(cmdSet.String()) + if err != nil { + return } } @@ -162,7 +169,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, mc.markBadConn(err) + // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. + errLog.Print(err) + return nil, driver.ErrBadConn } stmt := &mysqlStmt{ @@ -192,10 +201,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -221,6 +230,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin switch v := arg.(type) { case int64: buf = strconv.AppendInt(buf, v, 10) + case uint64: + // Handle uint64 explicitly because our custom ConvertValue emits unsigned values + buf = strconv.AppendUint(buf, v, 10) case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: @@ -233,47 +245,21 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { - v := v.In(mc.cfg.Loc) - v = v.Add(time.Nanosecond * 500) // To round under microsecond - year := v.Year() - year100 := year / 100 - year1 := year % 100 - month := v.Month() - day := v.Day() - hour := v.Hour() - minute := v.Minute() - second := v.Second() - micro := v.Nanosecond() / 1000 - - buf = append(buf, []byte{ - '\'', - digits10[year100], digits01[year100], - digits10[year1], digits01[year1], - '-', - digits10[month], digits01[month], - '-', - digits10[day], digits01[day], - ' ', - digits10[hour], digits01[hour], - ':', - digits10[minute], digits01[minute], - ':', - digits10[second], digits01[second], - }...) - - if micro != 0 { - micro10000 := micro / 10000 - micro100 := micro / 100 % 100 - micro1 := micro % 100 - buf = append(buf, []byte{ - '.', - digits10[micro10000], digits01[micro10000], - digits10[micro100], digits01[micro100], - digits10[micro1], digits01[micro1], - }...) + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + if err != nil { + return "", err } buf = append(buf, '\'') } + case json.RawMessage: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') case []byte: if v == nil { buf = append(buf, "NULL"...) @@ -475,7 +461,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { defer mc.finish() if err = mc.writeCommandPacket(comPing); err != nil { - return + return mc.markBadConn(err) } return mc.readResultOK() @@ -483,6 +469,10 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if mc.closed.IsSet() { + return nil, driver.ErrBadConn + } + if err := mc.watchCancel(ctx); err != nil { return nil, err } @@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error { mc.cleanup() return nil } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. if ctx.Done() == nil { return nil } - - mc.watching = true - select { - default: - case <-ctx.Done(): - return ctx.Err() - } + // When watcher is not alive, can't watch it. if mc.watcher == nil { return nil } + mc.watching = true mc.watcher <- ctx - return nil } func (mc *mysqlConn) startWatcher() { - watcher := make(chan mysqlContext, 1) + watcher := make(chan context.Context, 1) mc.watcher = watcher finished := make(chan struct{}) mc.finished = finished go func() { for { - var ctx mysqlContext + var ctx context.Context select { case ctx = <-watcher: case <-mc.closech: @@ -650,5 +639,12 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { if mc.closed.IsSet() { return driver.ErrBadConn } + mc.reset = true return nil } + +// IsValid implements driver.Validator interface +// (From Go 1.15) +func (mc *mysqlConn) IsValid() bool { + return !mc.closed.IsSet() +} diff --git a/connection_test.go b/connection_test.go index dec376117..a6d677308 100644 --- a/connection_test.go +++ b/connection_test.go @@ -9,7 +9,11 @@ package mysql import ( + "context" "database/sql/driver" + "encoding/json" + "errors" + "net" "testing" ) @@ -33,6 +37,33 @@ func TestInterpolateParams(t *testing.T) { } } +func TestInterpolateParamsJSONRawMessage(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + buf, err := json.Marshal(struct { + Value int `json:"value"` + }{Value: 42}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + q, err := mc.interpolateParams("SELECT ?", []driver.Value{json.RawMessage(buf)}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT '{\"value\":42}'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(nil), @@ -66,6 +97,24 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { } } +func TestInterpolateParamsUint64(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)}) + if err != nil { + t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q) + } + if q != "SELECT 42" { + t.Errorf("Expected uint64 interpolation to work, got q=%#v", q) + } +} + func TestCheckNamedValue(t *testing.T) { value := driver.NamedValue{Value: ^uint64(0)} x := &mysqlConn{} @@ -75,7 +124,80 @@ func TestCheckNamedValue(t *testing.T) { t.Fatal("uint64 high-bit not convertible", err) } - if value.Value != "18446744073709551615" { - t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) + if value.Value != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value) + } +} + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil } diff --git a/connector.go b/connector.go new file mode 100644 index 000000000..d567b4e4f --- /dev/null +++ b/connector.go @@ -0,0 +1,146 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "net" +) + +type connector struct { + cfg *Config // immutable private copy. +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.cfg, + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + mc.netConn, err = dial(dctx, mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } + + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + if err := mc.watchCancel(ctx); err != nil { + mc.cleanup() + return nil, err + } + defer mc.finish() + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + if plugin == "" { + plugin = defaultAuthPlugin + } + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +// Driver implements driver.Connector interface. +// Driver returns &MySQLDriver{}. +func (c *connector) Driver() driver.Driver { + return &MySQLDriver{} +} diff --git a/connector_test.go b/connector_test.go new file mode 100644 index 000000000..976903c5b --- /dev/null +++ b/connector_test.go @@ -0,0 +1,30 @@ +package mysql + +import ( + "context" + "net" + "testing" + "time" +) + +func TestConnectorReturnsTimeout(t *testing.T) { + connector := &connector{&Config{ + Net: "tcp", + Addr: "1.1.1.1:1234", + Timeout: 10 * time.Millisecond, + }} + + _, err := connector.Connect(context.Background()) + if err == nil { + t.Fatal("error expected") + } + + if nerr, ok := err.(*net.OpError); ok { + expected := "dial tcp 1.1.1.1:1234: i/o timeout" + if nerr.Error() != expected { + t.Fatalf("expected %q, got %q", expected, nerr.Error()) + } + } else { + t.Fatalf("expected %T, got %T", nerr, err) + } +} diff --git a/driver.go b/driver.go index 8c35de73c..c1bdf1199 100644 --- a/driver.go +++ b/driver.go @@ -17,6 +17,7 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" "net" @@ -29,134 +30,78 @@ type MySQLDriver struct{} // DialFunc is a function which can be used to establish the network connection. // Custom dial functions must be registered with RegisterDial +// +// Deprecated: users should register a DialContextFunc instead type DialFunc func(addr string) (net.Conn, error) +// DialContextFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDialContext +type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) + var ( dialsLock sync.RWMutex - dials map[string]DialFunc + dials map[string]DialContextFunc ) -// RegisterDial registers a custom dial function. It can then be used by the +// RegisterDialContext registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. -// addr is passed as a parameter to the dial function. -func RegisterDial(net string, dial DialFunc) { +// The current context for the connection and its address is passed to the dial function. +func RegisterDialContext(net string, dial DialContextFunc) { dialsLock.Lock() defer dialsLock.Unlock() if dials == nil { - dials = make(map[string]DialFunc) + dials = make(map[string]DialContextFunc) } dials[net] = dial } +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +// +// Deprecated: users should call RegisterDialContext instead +func RegisterDial(network string, dial DialFunc) { + RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) { + return dial(addr) + }) +} + // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated +// the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { - var err error - - // New mysqlConn - mc := &mysqlConn{ - maxAllowedPacket: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - closech: make(chan struct{}), - } - mc.cfg, err = ParseDSN(dsn) - if err != nil { - return nil, err - } - mc.parseTime = mc.cfg.ParseTime - - // Connect to Server - dialsLock.RLock() - dial, ok := dials[mc.cfg.Net] - dialsLock.RUnlock() - if ok { - mc.netConn, err = dial(mc.cfg.Addr) - } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) - } + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } - - // Enable TCP Keepalives on TCP connections - if tc, ok := mc.netConn.(*net.TCPConn); ok { - if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err - } - } - - // Call startWatcher for context support (From Go 1.8) - mc.startWatcher() - - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout - - // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() - if err != nil { - mc.cleanup() - return nil, err + c := &connector{ + cfg: cfg, } + return c.Connect(context.Background()) +} - // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) - if err != nil { - // try the default auth plugin, if using the requested plugin failed - errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) - plugin = defaultAuthPlugin - authResp, addNUL, err = mc.auth(authData, plugin) - if err != nil { - mc.cleanup() - return nil, err - } - } - if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { - mc.cleanup() - return nil, err - } +func init() { + sql.Register("mysql", &MySQLDriver{}) +} - // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { - // Authentication failed and MySQL has already closed the connection - // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). - // Do not send COM_QUIT, just cleanup and return the error. - mc.cleanup() +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + cfg = cfg.Clone() + // normalize the contents of cfg so calls to NewConnector have the same + // behavior as MySQLDriver.OpenConnector + if err := cfg.normalize(); err != nil { return nil, err } + return &connector{cfg: cfg}, nil +} - if mc.cfg.MaxAllowedPacket > 0 { - mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket - } else { - // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") - if err != nil { - mc.Close() - return nil, err - } - mc.maxAllowedPacket = stringToInt(maxap) - 1 - } - if mc.maxAllowedPacket < maxPacketSize { - mc.maxWriteSize = mc.maxAllowedPacket - } - - // Handle DSN Params - err = mc.handleParams() +// OpenConnector implements driver.DriverContext. +func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) if err != nil { - mc.Close() return nil, err } - - return mc, nil -} - -func init() { - sql.Register("mysql", &MySQLDriver{}) + return &connector{ + cfg: cfg, + }, nil } diff --git a/driver_test.go b/driver_test.go index e45441efa..949d7f332 100644 --- a/driver_test.go +++ b/driver_test.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "database/sql" "database/sql/driver" + "encoding/json" "fmt" "io" "io/ioutil" @@ -23,6 +24,7 @@ import ( "net/url" "os" "reflect" + "runtime" "strings" "sync" "sync/atomic" @@ -85,6 +87,23 @@ type DBTest struct { db *sql.DB } +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) @@ -195,6 +214,7 @@ func TestEmptyQuery(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // just a comment, no query rows := dbt.mustQuery("--") + defer rows.Close() // will hang before #255 if rows.Next() { dbt.Errorf("next on rows must be false") @@ -213,6 +233,7 @@ func TestCRUD(t *testing.T) { if rows.Next() { dbt.Error("unexpected data in empty table") } + rows.Close() // Create Data res := dbt.mustExec("INSERT INTO test VALUES (1)") @@ -246,6 +267,7 @@ func TestCRUD(t *testing.T) { } else { dbt.Error("no data") } + rows.Close() // Update res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) @@ -271,6 +293,7 @@ func TestCRUD(t *testing.T) { } else { dbt.Error("no data") } + rows.Close() // Delete res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) @@ -334,6 +357,7 @@ func TestMultiQuery(t *testing.T) { } else { dbt.Error("no data") } + rows.Close() }) } @@ -360,6 +384,7 @@ func TestInt(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } @@ -379,6 +404,7 @@ func TestInt(t *testing.T) { } else { dbt.Errorf("%s ZEROFILL: no data", v) } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } @@ -403,6 +429,7 @@ func TestFloat32(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } }) @@ -426,6 +453,7 @@ func TestFloat64(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } }) @@ -449,6 +477,7 @@ func TestFloat64Placeholder(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } }) @@ -475,6 +504,7 @@ func TestString(t *testing.T) { } else { dbt.Errorf("%s: no data", v) } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") } @@ -507,6 +537,7 @@ func TestRawBytes(t *testing.T) { v1 := []byte("aaa") v2 := []byte("bbb") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() if rows.Next() { var o1, o2 sql.RawBytes if err := rows.Scan(&o1, &o2); err != nil { @@ -530,6 +561,29 @@ func TestRawBytes(t *testing.T) { }) } +func TestRawMessage(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := json.RawMessage("{}") + v2 := json.RawMessage("[]") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 json.RawMessage + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + type testValuer struct { value string } @@ -555,6 +609,7 @@ func TestValuer(t *testing.T) { } else { dbt.Errorf("Valuer: no data") } + rows.Close() dbt.mustExec("DROP TABLE IF EXISTS test") }) @@ -867,6 +922,7 @@ func TestTimestampMicros(t *testing.T) { dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) var res0, res1, res6 string rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() if !rows.Next() { dbt.Errorf("test contained no selectable values") } @@ -1025,6 +1081,7 @@ func TestNULL(t *testing.T) { var out interface{} rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() if rows.Next() { rows.Scan(&out) if out != nil { @@ -1104,6 +1161,7 @@ func TestLongData(t *testing.T) { inS := in[:maxAllowedPacketSize-nonDataQueryLen] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() if rows.Next() { rows.Scan(&out) if inS != out { @@ -1122,6 +1180,7 @@ func TestLongData(t *testing.T) { // Long binary data dbt.mustExec("INSERT INTO test VALUES(?)", in) rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) + defer rows.Close() if rows.Next() { rows.Scan(&out) if in != out { @@ -1287,7 +1346,7 @@ func TestFoundRows(t *testing.T) { } func TestTLS(t *testing.T) { - tlsTest := func(dbt *DBTest) { + tlsTestReq := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { if err == ErrNoTLS { dbt.Skip("server does not support TLS") @@ -1297,6 +1356,7 @@ func TestTLS(t *testing.T) { } rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") + defer rows.Close() var variable, value *sql.RawBytes for rows.Next() { @@ -1304,19 +1364,27 @@ func TestTLS(t *testing.T) { dbt.Fatal(err.Error()) } - if value == nil { - dbt.Fatal("no Cipher") + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) } } } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } - runTests(t, dsn+"&tls=skip-verify", tlsTest) + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) // Verify that registering / using a custom cfg works RegisterTLSConfig("custom-skip-verify", &tls.Config{ InsecureSkipVerify: true, }) - runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) } func TestReuseClosedConnection(t *testing.T) { @@ -1382,11 +1450,11 @@ func TestCharset(t *testing.T) { mustSetCharset("charset=ascii", "ascii") // when the first charset is invalid, use the second - mustSetCharset("charset=none,utf8", "utf8") + mustSetCharset("charset=none,utf8mb4", "utf8mb4") // when the first charset is valid, use it - mustSetCharset("charset=ascii,utf8", "ascii") - mustSetCharset("charset=utf8,ascii", "utf8") + mustSetCharset("charset=ascii,utf8mb4", "ascii") + mustSetCharset("charset=utf8mb4,ascii", "utf8mb4") } func TestFailingCharset(t *testing.T) { @@ -1405,7 +1473,7 @@ func TestCollation(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - defaultCollation := "utf8_general_ci" + defaultCollation := "utf8mb4_general_ci" testCollations := []string{ "", // do not set defaultCollation, // driver default @@ -1449,9 +1517,9 @@ func TestColumnsWithAlias(t *testing.T) { if cols[0] != "A" { t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) } - rows.Close() rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + defer rows.Close() cols, _ = rows.Columns() if len(cols) != 1 { t.Fatalf("expected 1 column, got %d", len(cols)) @@ -1495,6 +1563,7 @@ func TestTimezoneConversion(t *testing.T) { // Retrieve time from DB rows := dbt.mustQuery("SELECT ts FROM test") + defer rows.Close() if !rows.Next() { dbt.Fatal("did not get any rows out") } @@ -1738,6 +1807,14 @@ func TestConcurrent(t *testing.T) { } runTests(t, dsn, func(dbt *DBTest) { + var version string + if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if strings.Contains(strings.ToLower(version), "mariadb") { + t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`) + } + var max int err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) if err != nil { @@ -1801,6 +1878,38 @@ func TestConcurrent(t *testing.T) { }) } +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, testErr) +} + // Tests custom dial functions func TestCustomDial(t *testing.T) { if !available { @@ -1808,8 +1917,9 @@ func TestCustomDial(t *testing.T) { } // our custom dial function which justs wraps net.Dial here - RegisterDial("mydial", func(addr string) (net.Conn, error) { - return net.Dial(prot, addr) + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, prot, addr) }) db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) @@ -1960,6 +2070,7 @@ func TestInterruptBySignal(t *testing.T) { dbt.Errorf("expected val to be 42") } } + rows.Close() // binary protocol rows, err = dbt.db.Query("CALL test_signal(?)", 42) @@ -1973,6 +2084,7 @@ func TestInterruptBySignal(t *testing.T) { dbt.Errorf("expected val to be 42") } } + rows.Close() }) } @@ -2552,10 +2664,19 @@ func TestContextCancelStmtQuery(t *testing.T) { } func TestContextCancelBegin(t *testing.T) { + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`) + } + runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - tx, err := dbt.db.BeginTx(ctx, nil) + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatal(err) + } + defer conn.Close() + tx, err := conn.BeginTx(ctx, nil) if err != nil { dbt.Fatal(err) } @@ -2585,7 +2706,17 @@ func TestContextCancelBegin(t *testing.T) { dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) } - // Context is canceled, so cannot begin a transaction. + // The connection is now in an inoperable state - so performing other + // operations should fail with ErrBadConn + // Important to exercise isolation level too - it runs SET TRANSACTION ISOLATION + // LEVEL XXX first, which needs to return ErrBadConn if the connection's context + // is cancelled + _, err = conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != driver.ErrBadConn { + dbt.Errorf("expected driver.ErrBadConn, got %v", err) + } + + // cannot begin a transaction (on a different conn) with a canceled context if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } @@ -2690,13 +2821,13 @@ func TestRowsColumnTypes(t *testing.T) { nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} - nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} - nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} - nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} - nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} - nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} - nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} - ndNULL := NullTime{Time: time.Time{}, Valid: false} + nt0 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := sql.NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := sql.NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := sql.NullTime{Time: time.Time{}, Valid: false} rbNULL := sql.RawBytes(nil) rb0 := sql.RawBytes("0") rb42 := sql.RawBytes("42") @@ -2727,10 +2858,10 @@ func TestRowsColumnTypes(t *testing.T) { {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, - {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, - {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, - {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, - {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, @@ -2776,131 +2907,125 @@ func TestRowsColumnTypes(t *testing.T) { values2 = values2[:len(values2)-2] values3 = values3[:len(values3)-2] - dsns := []string{ - dsn + "&parseTime=true", - dsn + "&parseTime=false", - } - for _, testdsn := range dsns { - runTests(t, testdsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (" + schema + ")") - dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") - rows, err := dbt.db.Query("SELECT * FROM test") - if err != nil { - t.Fatalf("Query: %v", err) - } + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } - tt, err := rows.ColumnTypes() - if err != nil { - t.Fatalf("ColumnTypes: %v", err) - } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } - if len(tt) != len(columns) { - t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) - } + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } - types := make([]reflect.Type, len(tt)) - for i, tp := range tt { - column := columns[i] + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] - // Name - name := tp.Name() - if name != column.name { - t.Errorf("column name mismatch %s != %s", name, column.name) - continue - } + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } - // DatabaseTypeName - databaseTypeName := tp.DatabaseTypeName() - if databaseTypeName != column.databaseTypeName { - t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) - continue - } + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } - // ScanType - scanType := tp.ScanType() - if scanType != column.scanType { - if scanType == nil { - t.Errorf("scantype is null for column %q", name) - } else { - t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) - } - continue + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) } - types[i] = scanType - - // Nullable - nullable, ok := tp.Nullable() + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { if !ok { - t.Errorf("nullable not ok %q", name) - continue - } - if nullable != column.nullable { - t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) } - - // Length - // length, ok := tp.Length() - // if length != column.length { - // if !ok { - // t.Errorf("length not ok for column %q", name) - // } else { - // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) - // } - // continue - // } - - // Precision and Scale - precision, scale, ok := tp.DecimalSize() - if precision != column.precision { - if !ok { - t.Errorf("precision not ok for column %q", name) - } else { - t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) - } - continue - } - if scale != column.scale { - if !ok { - t.Errorf("scale not ok for column %q", name) - } else { - t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) - } - continue + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) } + continue } + } - values := make([]interface{}, len(tt)) - for i := range values { - values[i] = reflect.New(types[i]).Interface() + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) } - i := 0 - for rows.Next() { - err = rows.Scan(values...) - if err != nil { - t.Fatalf("failed to scan values in %v", err) - } - for j := range values { - value := reflect.ValueOf(values[j]).Elem().Interface() - if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { - if columns[j].scanType == scanTypeRawBytes { - t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) - } else { - t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) - } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) } } - i++ - } - if i != 3 { - t.Errorf("expected 3 rows, got %d", i) } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } - if err := rows.Close(); err != nil { - t.Errorf("error closing rows: %s", err) - } - }) - } + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) } func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { @@ -2910,3 +3035,227 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() }) } + +// TestRawBytesAreNotModified checks for a race condition that arises when a query context +// is canceled while a user is calling rows.Scan. This is a more stringent test than the one +// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using +// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit +// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers. +func TestRawBytesAreNotModified(t *testing.T) { + const blob = "abcdefghijklmnop" + const contextRaceIterations = 20 + const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row. + const insertRows = 4 + + var sqlBlobs = [2]string{ + strings.Repeat(blob, blobSize/len(blob)), + strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + for i := 0; i < insertRows; i++ { + dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) + if err != nil { + t.Fatal(err) + } + + var b int + var raw sql.RawBytes + for rows.Next() { + if err := rows.Scan(&b, &raw); err != nil { + t.Fatal(err) + } + + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) + + if before != after { + t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + } + } + rows.Close() + }() + } + }) +} + +var _ driver.DriverContext = &MySQLDriver{} + +type dialCtxKey struct{} + +func TestConnectorObeysDialTimeouts(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + if !ctx.Value(dialCtxKey{}).(bool) { + return nil, fmt.Errorf("test error: query context is not propagated to our dialer") + } + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + ctx := context.WithValue(context.Background(), dialCtxKey{}, true) + + _, err = db.ExecContext(ctx, "DO 1") + if err != nil { + t.Fatal(err) + } +} + +func configForTests(t *testing.T) *Config { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mycnf := NewConfig() + mycnf.User = user + mycnf.Passwd = pass + mycnf.Addr = addr + mycnf.Net = prot + mycnf.DBName = dbname + return mycnf +} + +func TestNewConnector(t *testing.T) { + mycnf := configForTests(t) + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +type slowConnection struct { + net.Conn + slowdown time.Duration +} + +func (sc *slowConnection) Read(b []byte) (int, error) { + time.Sleep(sc.slowdown) + return sc.Conn.Read(b) +} + +type connectorHijack struct { + driver.Connector + connErr error +} + +func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + conn, cw.connErr = cw.Connector.Connect(ctx) + return conn, cw.connErr +} + +func TestConnectorTimeoutsDuringOpen(t *testing.T) { + RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, prot, addr) + if err != nil { + return nil, err + } + return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil + }) + + mycnf := configForTests(t) + mycnf.Net = "slowconn" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + hijack := &connectorHijack{Connector: conn} + + db := sql.OpenDB(hijack) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = db.ExecContext(ctx, "DO 1") + if err != context.DeadlineExceeded { + t.Fatalf("ExecContext should have timed out") + } + if hijack.connErr != context.DeadlineExceeded { + t.Fatalf("(*Connector).Connect should have timed out") + } +} + +// A connection which can only be closed. +type dummyConnection struct { + net.Conn + closed bool +} + +func (d *dummyConnection) Close() error { + d.closed = true + return nil +} + +func TestConnectorTimeoutsWatchCancel(t *testing.T) { + var ( + cancel func() // Used to cancel the context just after connecting. + created *dummyConnection // The created connection. + ) + + RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) { + // Canceling at this time triggers the watchCancel error branch in Connect(). + cancel() + created = &dummyConnection{} + return created, nil + }) + + mycnf := NewConfig() + mycnf.User = "root" + mycnf.Addr = "foo" + mycnf.Net = "TestConnectorTimeoutsWatchCancel" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + + if _, err := db.Conn(ctx); err != context.Canceled { + t.Errorf("got %v, want context.Canceled", err) + } + + if created == nil { + t.Fatal("no connection created") + } + if !created.closed { + t.Errorf("connection not closed") + } +} diff --git a/dsn.go b/dsn.go index d0098ff71..7227ae7f5 100644 --- a/dsn.go +++ b/dsn.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "errors" "fmt" + "math/big" "net" "net/url" "sort" @@ -55,6 +56,7 @@ type Config struct { AllowCleartextPasswords bool // Allows the cleartext client side plugin AllowNativePasswords bool // Allows the native password authentication method AllowOldPasswords bool // Allows the old insecure password method + CheckConnLiveness bool // Check connections for liveness before using them ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names InterpolateParams bool // Interpolate placeholders into query string @@ -70,9 +72,30 @@ func NewConfig() *Config { Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, + CheckConnLiveness: true, } } +func (cfg *Config) Clone() *Config { + cp := *cfg + if cp.tls != nil { + cp.tls = cfg.tls.Clone() + } + if len(cp.Params) > 0 { + cp.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + cp.Params[k] = v + } + } + if cfg.pubKey != nil { + cp.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &cp +} + func (cfg *Config) normalize() error { if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { return errInvalidDSNUnsafeCollation @@ -93,23 +116,54 @@ func (cfg *Config) normalize() error { default: return errors.New("default addr for network '" + cfg.Net + "' unknown") } - } else if cfg.Net == "tcp" { cfg.Addr = ensureHavePort(cfg.Addr) } - if cfg.tls != nil { - if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - cfg.tls.ServerName = host - } + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.tls = &tls.Config{} + case "skip-verify", "preferred": + cfg.tls = &tls.Config{InsecureSkipVerify: true} + default: + cfg.tls = getTLSConfigClone(cfg.TLSConfig) + if cfg.tls == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } + } + + if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + + if cfg.ServerPubKey != "" { + cfg.pubKey = getServerPubKey(cfg.ServerPubKey) + if cfg.pubKey == nil { + return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey) } } return nil } +func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) { + buf.Grow(1 + len(name) + 1 + len(value)) + if !*hasParam { + *hasParam = true + buf.WriteByte('?') + } else { + buf.WriteByte('&') + } + buf.WriteString(name) + buf.WriteByte('=') + buf.WriteString(value) +} + // FormatDSN formats the given Config into a DSN string which can be passed to // the driver. func (cfg *Config) FormatDSN() string { @@ -148,165 +202,75 @@ func (cfg *Config) FormatDSN() string { } if cfg.AllowCleartextPasswords { - if hasParam { - buf.WriteString("&allowCleartextPasswords=true") - } else { - hasParam = true - buf.WriteString("?allowCleartextPasswords=true") - } + writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") } if !cfg.AllowNativePasswords { - if hasParam { - buf.WriteString("&allowNativePasswords=false") - } else { - hasParam = true - buf.WriteString("?allowNativePasswords=false") - } + writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false") } if cfg.AllowOldPasswords { - if hasParam { - buf.WriteString("&allowOldPasswords=true") - } else { - hasParam = true - buf.WriteString("?allowOldPasswords=true") - } + writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true") + } + + if !cfg.CheckConnLiveness { + writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false") } if cfg.ClientFoundRows { - if hasParam { - buf.WriteString("&clientFoundRows=true") - } else { - hasParam = true - buf.WriteString("?clientFoundRows=true") - } + writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") } if col := cfg.Collation; col != defaultCollation && len(col) > 0 { - if hasParam { - buf.WriteString("&collation=") - } else { - hasParam = true - buf.WriteString("?collation=") - } - buf.WriteString(col) + writeDSNParam(&buf, &hasParam, "collation", col) } if cfg.ColumnsWithAlias { - if hasParam { - buf.WriteString("&columnsWithAlias=true") - } else { - hasParam = true - buf.WriteString("?columnsWithAlias=true") - } + writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } if cfg.InterpolateParams { - if hasParam { - buf.WriteString("&interpolateParams=true") - } else { - hasParam = true - buf.WriteString("?interpolateParams=true") - } + writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } if cfg.Loc != time.UTC && cfg.Loc != nil { - if hasParam { - buf.WriteString("&loc=") - } else { - hasParam = true - buf.WriteString("?loc=") - } - buf.WriteString(url.QueryEscape(cfg.Loc.String())) + writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String())) } if cfg.MultiStatements { - if hasParam { - buf.WriteString("&multiStatements=true") - } else { - hasParam = true - buf.WriteString("?multiStatements=true") - } + writeDSNParam(&buf, &hasParam, "multiStatements", "true") } if cfg.ParseTime { - if hasParam { - buf.WriteString("&parseTime=true") - } else { - hasParam = true - buf.WriteString("?parseTime=true") - } + writeDSNParam(&buf, &hasParam, "parseTime", "true") } if cfg.ReadTimeout > 0 { - if hasParam { - buf.WriteString("&readTimeout=") - } else { - hasParam = true - buf.WriteString("?readTimeout=") - } - buf.WriteString(cfg.ReadTimeout.String()) + writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String()) } if cfg.RejectReadOnly { - if hasParam { - buf.WriteString("&rejectReadOnly=true") - } else { - hasParam = true - buf.WriteString("?rejectReadOnly=true") - } + writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true") } if len(cfg.ServerPubKey) > 0 { - if hasParam { - buf.WriteString("&serverPubKey=") - } else { - hasParam = true - buf.WriteString("?serverPubKey=") - } - buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey)) } if cfg.Timeout > 0 { - if hasParam { - buf.WriteString("&timeout=") - } else { - hasParam = true - buf.WriteString("?timeout=") - } - buf.WriteString(cfg.Timeout.String()) + writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String()) } if len(cfg.TLSConfig) > 0 { - if hasParam { - buf.WriteString("&tls=") - } else { - hasParam = true - buf.WriteString("?tls=") - } - buf.WriteString(url.QueryEscape(cfg.TLSConfig)) + writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig)) } if cfg.WriteTimeout > 0 { - if hasParam { - buf.WriteString("&writeTimeout=") - } else { - hasParam = true - buf.WriteString("?writeTimeout=") - } - buf.WriteString(cfg.WriteTimeout.String()) + writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String()) } if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { - if hasParam { - buf.WriteString("&maxAllowedPacket=") - } else { - hasParam = true - buf.WriteString("?maxAllowedPacket=") - } - buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket)) - + writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket)) } if len(cfg.ConnectAttrs) > 0 { @@ -341,16 +305,7 @@ func (cfg *Config) FormatDSN() string { } sort.Strings(params) for _, param := range params { - if hasParam { - buf.WriteByte('&') - } else { - hasParam = true - buf.WriteByte('?') - } - - buf.WriteString(param) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(cfg.Params[param])) + writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param])) } } @@ -445,7 +400,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files + // Disable INFILE allowlist / enable all files case "allowAllFiles": var isBool bool cfg.AllowAllFiles, isBool = readBool(value) @@ -477,6 +432,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Check connections for Liveness before using them + case "checkConnLiveness": + var isBool bool + cfg.CheckConnLiveness, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Switch "rowsAffected" mode case "clientFoundRows": var isBool bool @@ -488,7 +451,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Collation case "collation": cfg.Collation = value - break case "columnsWithAlias": var isBool bool @@ -556,13 +518,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return fmt.Errorf("invalid value for server pub key name: %v", err) } - - if pubKey := getServerPubKey(name); pubKey != nil { - cfg.ServerPubKey = name - cfg.pubKey = pubKey - } else { - return errors.New("invalid value / unknown server pub key name: " + name) - } + cfg.ServerPubKey = name // Strict mode case "strict": @@ -581,25 +537,17 @@ func parseDSNParams(cfg *Config, params string) (err error) { if isBool { if boolValue { cfg.TLSConfig = "true" - cfg.tls = &tls.Config{} } else { cfg.TLSConfig = "false" } - } else if vl := strings.ToLower(value); vl == "skip-verify" { + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl - cfg.tls = &tls.Config{InsecureSkipVerify: true} } else { name, err := url.QueryUnescape(value) if err != nil { return fmt.Errorf("invalid value for TLS config name: %v", err) } - - if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { - cfg.TLSConfig = name - cfg.tls = tlsConfig - } else { - return errors.New("invalid value / unknown config name: " + name) - } + cfg.TLSConfig = name } // I/O write Timeout diff --git a/dsn_test.go b/dsn_test.go index dae4c24e0..728241e32 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -22,55 +22,55 @@ var testDSNs = []struct { out *Config }{{ "username:password@protocol(address)/dbname?param=value", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, }, { - "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216}, + "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, }, { - "user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false}, + "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", - &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "@/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:p@/ssword@/", - &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "unix/?arg=%2Fsome%2Fpath.ext", - &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(127.0.0.1)/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(de:ad:be:ef::ca:fe)/dbname", - &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(127.0.0.1)/dbname?connectAttrs=program_name:SomeService", &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", ConnectAttrs: map[string]string{"program_name": "SomeService"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, @@ -95,13 +95,14 @@ func TestDSNParser(t *testing.T) { func TestDSNParserInvalid(t *testing.T) { var invalidDSNs = []string{ - "@net(addr/", // no closing brace - "@tcp(/", // no closing brace - "tcp(/", // no closing brace - "(/", // no closing brace - "net(addr)//", // unescaped - "User:pass@tcp(1.2.3.4:3306)", // no trailing slash - "net()/", // unknown default addr + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr + "user:pass@tcp(127.0.0.1:3306)/db/name", // invalid dbname //"/dbname?arg=/some/unescaped/path", } @@ -321,6 +322,90 @@ func TestParamsAreSorted(t *testing.T) { } } +func TestCloneConfig(t *testing.T) { + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey" + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err.Error()) + } + + cfg2 := cfg.Clone() + if cfg == cfg2 { + t.Errorf("Config.Clone did not create a separate config struct") + } + + if cfg2.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + cfg2.tls.ServerName = "example2.com" + if cfg.tls.ServerName == cfg2.tls.ServerName { + t.Errorf("changed cfg.tls.Server name should not propagate to original Config") + } + + if _, ok := cfg2.Params["foobar"]; !ok { + t.Errorf("cloned Config is missing custom params") + } + + delete(cfg2.Params, "foobar") + + if _, ok := cfg.Params["foobar"]; !ok { + t.Errorf("custom params in cloned Config should not propagate to original Config") + } + + if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) { + t.Errorf("public key in Config should be identical") + } +} + +func TestNormalizeTLSConfig(t *testing.T) { + tt := []struct { + tlsConfig string + want *tls.Config + }{ + {"", nil}, + {"false", nil}, + {"true", &tls.Config{ServerName: "myserver"}}, + {"skip-verify", &tls.Config{InsecureSkipVerify: true}}, + {"preferred", &tls.Config{InsecureSkipVerify: true}}, + {"test_tls_config", &tls.Config{ServerName: "myServerName"}}, + } + + RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"}) + defer func() { DeregisterTLSConfig("test_tls_config") }() + + for _, tc := range tt { + t.Run(tc.tlsConfig, func(t *testing.T) { + cfg := &Config{ + Addr: "myserver:3306", + TLSConfig: tc.tlsConfig, + } + + cfg.normalize() + + if cfg.tls == nil { + if tc.want != nil { + t.Fatal("wanted a tls config but got nil instead") + } + return + } + + if cfg.tls.ServerName != tc.want.ServerName { + t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", + tc.want.ServerName, cfg.tls.ServerName) + } + if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", + tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + } + }) + } +} + func TestAttributesAreSorted(t *testing.T) { expected := "/dbname?connectAttrs=p1:v1,p2:v2" cfg := NewConfig() diff --git a/errors.go b/errors.go index 760782ff2..92cc9a361 100644 --- a/errors.go +++ b/errors.go @@ -63,3 +63,10 @@ type MySQLError struct { func (me *MySQLError) Error() string { return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } + +func (me *MySQLError) Is(err error) bool { + if merr, ok := err.(*MySQLError); ok { + return merr.Number == me.Number + } + return false +} diff --git a/errors_test.go b/errors_test.go index 96f9126d6..3a1aef74d 100644 --- a/errors_test.go +++ b/errors_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "errors" "log" "testing" ) @@ -40,3 +41,21 @@ func TestErrorsStrictIgnoreNotes(t *testing.T) { dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") }) } + +func TestMySQLErrIs(t *testing.T) { + infraErr := &MySQLError{1234, "the server is on fire"} + otherInfraErr := &MySQLError{1234, "the datacenter is flooded"} + if !errors.Is(infraErr, otherInfraErr) { + t.Errorf("expected errors to be the same: %+v %+v", infraErr, otherInfraErr) + } + + differentCodeErr := &MySQLError{5678, "the server is on fire"} + if errors.Is(infraErr, differentCodeErr) { + t.Fatalf("expected errors to be different: %+v %+v", infraErr, differentCodeErr) + } + + nonMysqlErr := errors.New("not a mysql error") + if errors.Is(infraErr, nonMysqlErr) { + t.Fatalf("expected errors to be different: %+v %+v", infraErr, nonMysqlErr) + } +} diff --git a/fields.go b/fields.go index e1e2ece4b..e0654a83d 100644 --- a/fields.go +++ b/fields.go @@ -41,6 +41,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeJSON: return "JSON" case fieldTypeLong: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED INT" + } return "INT" case fieldTypeLongBLOB: if mf.charSet != collations[binaryCollation] { @@ -48,6 +51,9 @@ func (mf *mysqlField) typeDatabaseName() string { } return "LONGBLOB" case fieldTypeLongLong: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED BIGINT" + } return "BIGINT" case fieldTypeMediumBLOB: if mf.charSet != collations[binaryCollation] { @@ -63,6 +69,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeSet: return "SET" case fieldTypeShort: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED SMALLINT" + } return "SMALLINT" case fieldTypeString: if mf.charSet == collations[binaryCollation] { @@ -74,6 +83,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeTimestamp: return "TIMESTAMP" case fieldTypeTiny: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED TINYINT" + } return "TINYINT" case fieldTypeTinyBLOB: if mf.charSet != collations[binaryCollation] { @@ -106,7 +118,7 @@ var ( scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint32 = reflect.TypeOf(uint32(0)) diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 000000000..fa75adf6a --- /dev/null +++ b/fuzz.go @@ -0,0 +1,24 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build gofuzz + +package mysql + +import ( + "database/sql" +) + +func Fuzz(data []byte) int { + db, err := sql.Open("mysql", string(data)) + if err != nil { + return 0 + } + db.Close() + return 1 +} diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..251110478 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/go-sql-driver/mysql + +go 1.13 diff --git a/infile.go b/infile.go index 273cb0ba5..60effdfc2 100644 --- a/infile.go +++ b/infile.go @@ -23,7 +23,7 @@ var ( readerRegisterLock sync.RWMutex ) -// RegisterLocalFile adds the given file to the file whitelist, +// RegisterLocalFile adds the given file to the file allowlist, // so that it can be used by "LOAD DATA LOCAL INFILE ". // Alternatively you can allow the use of all local files with // the DSN parameter 'allowAllFiles=true' @@ -45,7 +45,7 @@ func RegisterLocalFile(filePath string) { fileRegisterLock.Unlock() } -// DeregisterLocalFile removes the given filepath from the whitelist. +// DeregisterLocalFile removes the given filepath from the allowlist. func DeregisterLocalFile(filePath string) { fileRegisterLock.Lock() delete(fileRegister, strings.Trim(filePath, `"`)) diff --git a/nulltime.go b/nulltime.go new file mode 100644 index 000000000..17af92ddc --- /dev/null +++ b/nulltime.go @@ -0,0 +1,71 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "time" +) + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +// +// Deprecated: NullTime doesn't honor the loc DSN parameter. +// NullTime.Scan interprets a time as UTC, not the loc DSN parameter. +// Use sql.NullTime instead. +type NullTime sql.NullTime + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime([]byte(v), time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} diff --git a/nulltime_test.go b/nulltime_test.go new file mode 100644 index 000000000..a14ec0607 --- /dev/null +++ b/nulltime_test.go @@ -0,0 +1,62 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "database/sql/driver" + "testing" + "time" +) + +var ( + // Check implementation of interfaces + _ driver.Valuer = NullTime{} + _ sql.Scanner = (*NullTime)(nil) +) + +func TestScanNullTime(t *testing.T) { + var scanTests = []struct { + in interface{} + error bool + valid bool + time time.Time + }{ + {tDate, false, true, tDate}, + {sDate, false, true, tDate}, + {[]byte(sDate), false, true, tDate}, + {tDateTime, false, true, tDateTime}, + {sDateTime, false, true, tDateTime}, + {[]byte(sDateTime), false, true, tDateTime}, + {tDate0, false, true, tDate0}, + {sDate0, false, true, tDate0}, + {[]byte(sDate0), false, true, tDate0}, + {sDateTime0, false, true, tDate0}, + {[]byte(sDateTime0), false, true, tDate0}, + {"", true, false, tDate0}, + {"1234", true, false, tDate0}, + {0, true, false, tDate0}, + } + + var nt = NullTime{} + var err error + + for _, tst := range scanTests { + err = nt.Scan(tst.in) + if (err != nil) != tst.error { + t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil)) + } + if nt.Valid != tst.valid { + t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid) + } + if nt.Time != tst.time { + t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time) + } + } +} diff --git a/packets.go b/packets.go index 13f85fe89..7d1134aa6 100644 --- a/packets.go +++ b/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -51,7 +52,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.sequence++ // packets with length 0 terminate a previous packet which is a - // multiple of (2^24)−1 bytes long + // multiple of (2^24)-1 bytes long if pktLen == 0 { // there was no previous packet if prevData == nil { @@ -96,6 +97,35 @@ func (mc *mysqlConn) writePacket(data []byte) error { return ErrPktTooLarge } + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + // If this connection has a ReadTimeout which we've been setting on + // reads, reset it to its default value before we attempt a non-blocking + // read, otherwise the scheduler will just time us out before we can read + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Time{}) + } + if err == nil && mc.cfg.CheckConnLiveness { + err = connCheck(conn) + } + if err != nil { + errLog.Print("closing bad idle connection: ", err) + mc.Close() + return driver.ErrBadConn + } + } + for { var size int if pktLen >= maxPacketSize { @@ -154,15 +184,15 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() if err != nil { // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since // in connection initialization we don't risk retrying non-idempotent actions. if err == ErrInvalidConn { return nil, "", driver.ErrBadConn } - return nil, "", err + return } if data[0] == iERR { @@ -194,11 +224,14 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, "", ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 - plugin := "" if len(data) > pos { // character set [1 byte] // status flags [2 bytes] @@ -241,8 +274,6 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { return b[:], plugin, nil } - plugin = defaultAuthPlugin - // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) @@ -251,7 +282,7 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth bool, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientLongPassword | @@ -280,7 +311,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth // encode length of the auth plugin data var authRespLEIBuf [9]byte - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) if len(authRespLEI) > 1 && clientFlags&clientSecureConn != 0 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer @@ -317,10 +349,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -346,6 +378,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth return errors.New("unknown collation") } + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { @@ -359,16 +397,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth if err := tlsConn.Handshake(); err != nil { return err } + mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn } - // Filler [23 bytes] (all 0x00) - pos := 13 - for ; pos < 13+23; pos++ { - data[pos] = 0 - } - // User [null terminated string] if len(mc.cfg.User) > 0 { pos += copy(data[pos:], mc.cfg.User) @@ -411,28 +444,21 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth } // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) - if addNUL { - pktLen++ - } - data := mc.buf.takeSmallBuffer(pktLen) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } // Add the auth data [EOF] copy(data[4:], authData) - if addNUL { - data[pktLen-1] = 0x00 - } - return mc.writePacket(data) } @@ -444,10 +470,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -463,10 +489,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -484,10 +510,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -524,7 +550,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { return data[1:], "", err case iEOF: - if len(data) < 1 { + if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest return nil, "mysql_old_password", nil } @@ -783,40 +809,40 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // RowSet Packet - var n int - var isNull bool - pos := 0 + var ( + n int + isNull bool + pos int = 0 + ) for i := range dest { // Read bytes and convert to string dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) pos += n - if err == nil { - if !isNull { - if !mc.parseTime { - continue - } else { - switch rows.rs.columns[i].fieldType { - case fieldTypeTimestamp, fieldTypeDateTime, - fieldTypeDate, fieldTypeNewDate: - dest[i], err = parseDateTime( - string(dest[i].([]byte)), - mc.cfg.Loc, - ) - if err == nil { - continue - } - default: - continue - } - } - } else { - dest[i] = nil - continue + if err != nil { + return err + } + + if isNull { + dest[i] = nil + continue + } + + if !mc.parseTime { + continue + } + + // Parse time field + switch rows.rs.columns[i].fieldType { + case fieldTypeTimestamp, + fieldTypeDateTime, + fieldTypeDate, + fieldTypeNewDate: + if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { + return err } } - return err // err != nil } return nil @@ -940,7 +966,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc - // Determine threshould dynamically to avoid packet size shortage. + // Determine threshold dynamically to avoid packet size shortage. longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) if longDataSize < 64 { longDataSize = 64 @@ -950,15 +976,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -984,7 +1012,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -993,10 +1021,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -1023,6 +1052,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { continue } + if v, ok := arg.(json.RawMessage); ok { + arg = []byte(v) + } // cache types and values switch v := arg.(type) { case int64: @@ -1041,6 +1073,22 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) } + case uint64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x80 // type is unsigned + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + case float64: paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 @@ -1116,7 +1164,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) + b, err = appendDateTime(b, v.In(mc.cfg.Loc)) + if err != nil { + return err + } } paramValues = appendLengthEncodedInteger(paramValues, @@ -1133,7 +1184,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues) diff --git a/rows.go b/rows.go index d3b1e2822..888bdb5f0 100644 --- a/rows.go +++ b/rows.go @@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) { return err } + // flip the buffer for this connection if we need to drain it. + // note that for a successful query (i.e. one where rows.next() + // has been called until it returns false), `rows.mc` will be nil + // by the time the user calls `(*Rows).Close`, so we won't reach this + // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.flip() + // Remove unread packets from stream if !rows.rs.done { err = mc.readUntilEOF() diff --git a/statement.go b/statement.go index ce7fe4cd0..18a3ae498 100644 --- a/statement.go +++ b/statement.go @@ -10,10 +10,10 @@ package mysql import ( "database/sql/driver" + "encoding/json" "fmt" "io" "reflect" - "strconv" ) type mysqlStmt struct { @@ -44,6 +44,11 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } +func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) @@ -130,6 +135,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +var jsonType = reflect.TypeOf(json.RawMessage{}) + type converter struct{} // ConvertValue mirrors the reference/default converter in database/sql/driver @@ -147,12 +154,17 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if err != nil { return nil, err } - if !driver.IsValue(sv) { - return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + if driver.IsValue(sv) { + return sv, nil + } + // A value returend from the Valuer interface can be "a type handled by + // a database driver's NamedValueChecker interface" so we should accept + // uint64 here as well. + if u, ok := sv.(uint64); ok { + return u, nil } - return sv, nil + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) } - rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: @@ -164,24 +176,21 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return int64(rv.Uint()), nil - case reflect.Uint64: - u64 := rv.Uint() - if u64 >= 1<<63 { - return strconv.FormatUint(u64, 10), nil - } - return int64(u64), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint(), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil case reflect.Bool: return rv.Bool(), nil case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 98a6c1933..2563ece55 100644 --- a/statement_test.go +++ b/statement_test.go @@ -10,6 +10,8 @@ package mysql import ( "bytes" + "database/sql/driver" + "encoding/json" "testing" ) @@ -34,7 +36,7 @@ func TestConvertDerivedByteSlice(t *testing.T) { t.Fatal("Byte slice not convertible", err) } - if bytes.Compare(output.([]byte), []byte("value")) != 0 { + if !bytes.Equal(output.([]byte), []byte("value")) { t.Fatalf("Byte slice not converted, got %#v %T", output, output) } } @@ -95,6 +97,14 @@ func TestConvertSignedIntegers(t *testing.T) { } } +type myUint64 struct { + value uint64 +} + +func (u myUint64) Value() (driver.Value, error) { + return u.value, nil +} + func TestConvertUnsignedIntegers(t *testing.T) { values := []interface{}{ uint8(42), @@ -102,6 +112,7 @@ func TestConvertUnsignedIntegers(t *testing.T) { uint32(42), uint64(42), uint(42), + myUint64{uint64(42)}, } for _, value := range values { @@ -110,7 +121,7 @@ func TestConvertUnsignedIntegers(t *testing.T) { t.Fatalf("%T type not convertible %s", value, err) } - if output != int64(42) { + if output != uint64(42) { t.Fatalf("%T type not converted, got %#v %T", value, output, output) } } @@ -120,7 +131,21 @@ func TestConvertUnsignedIntegers(t *testing.T) { t.Fatal("uint64 high-bit not convertible", err) } - if output != "18446744073709551615" { - t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output) + if output != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) + } +} + +func TestConvertJSON(t *testing.T) { + raw := json.RawMessage("{}") + + out, err := converter{}.ConvertValue(raw) + + if err != nil { + t.Fatal("json.RawMessage was failed in convert", err) + } + + if _, ok := out.(json.RawMessage); !ok { + t.Fatalf("json.RawMessage converted, got %#v %T", out, out) } } diff --git a/utils.go b/utils.go index bb1be6f1a..d303f9dc0 100644 --- a/utils.go +++ b/utils.go @@ -56,7 +56,7 @@ var ( // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { - if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { return fmt.Errorf("key '%s' is reserved", key) } @@ -106,81 +106,136 @@ func readBool(input string) (value bool, valid bool) { * Time related utils * ******************************************************************************/ -// NullTime represents a time.Time that may be NULL. -// NullTime implements the Scanner interface so -// it can be used as a scan destination: -// -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } -// -// This NullTime implementation is not driver-specific -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} +func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { + const base = "0000-00-00 00:00:00.000000" + switch len(b) { + case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" + if string(b) == base[:len(b)] { + return time.Time{}, nil + } -// Scan implements the Scanner interface. -// The value type must be time.Time or string / []byte (formatted time-string), -// otherwise Scan fails. -func (nt *NullTime) Scan(value interface{}) (err error) { - if value == nil { - nt.Time, nt.Valid = time.Time{}, false - return - } + year, err := parseByteYear(b) + if err != nil { + return time.Time{}, err + } + if year <= 0 { + year = 1 + } + + if b[4] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) + } + + m, err := parseByte2Digits(b[5], b[6]) + if err != nil { + return time.Time{}, err + } + if m <= 0 { + m = 1 + } + month := time.Month(m) + + if b[7] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7]) + } + + day, err := parseByte2Digits(b[8], b[9]) + if err != nil { + return time.Time{}, err + } + if day <= 0 { + day = 1 + } + if len(b) == 10 { + return time.Date(year, month, day, 0, 0, 0, 0, loc), nil + } + + if b[10] != ' ' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10]) + } + + hour, err := parseByte2Digits(b[11], b[12]) + if err != nil { + return time.Time{}, err + } + if b[13] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13]) + } + + min, err := parseByte2Digits(b[14], b[15]) + if err != nil { + return time.Time{}, err + } + if b[16] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16]) + } + + sec, err := parseByte2Digits(b[17], b[18]) + if err != nil { + return time.Time{}, err + } + if len(b) == 19 { + return time.Date(year, month, day, hour, min, sec, 0, loc), nil + } - switch v := value.(type) { - case time.Time: - nt.Time, nt.Valid = v, true - return - case []byte: - nt.Time, err = parseDateTime(string(v), time.UTC) - nt.Valid = (err == nil) - return - case string: - nt.Time, err = parseDateTime(v, time.UTC) - nt.Valid = (err == nil) - return + if b[19] != '.' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19]) + } + nsec, err := parseByteNanoSec(b[20:]) + if err != nil { + return time.Time{}, err + } + return time.Date(year, month, day, hour, min, sec, nsec, loc), nil + default: + return time.Time{}, fmt.Errorf("invalid time bytes: %s", b) } +} - nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) +func parseByteYear(b []byte) (int, error) { + year, n := 0, 1000 + for i := 0; i < 4; i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + year += v * n + n /= 10 + } + return year, nil } -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil +func parseByte2Digits(b1, b2 byte) (int, error) { + d1, err := bToi(b1) + if err != nil { + return 0, err } - return nt.Time, nil + d2, err := bToi(b2) + if err != nil { + return 0, err + } + return d1*10 + d2, nil } -func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { - base := "0000-00-00 00:00:00.0000000" - switch len(str) { - case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" - if str == base[:len(str)] { - return +func parseByteNanoSec(b []byte) (int, error) { + ns, digit := 0, 100000 // max is 6-digits + for i := 0; i < len(b); i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err } - t, err = time.Parse(timeFormat[:len(str)], str) - default: - err = fmt.Errorf("invalid time string: %s", str) - return + ns += v * digit + digit /= 10 } + // nanoseconds has 10-digits. (needs to scale digits) + // 10 - 6 = 4, so we have to multiple 1000. + return ns * 1000, nil +} - // Adjust location - if err == nil && loc != time.UTC { - y, mo, d := t.Date() - h, mi, s := t.Clock() - t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil +func bToi(b byte) (int, error) { + if b < '0' || b > '9' { + return 0, errors.New("not [0-9]") } - - return + return int(b - '0'), nil } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { @@ -221,6 +276,64 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } +func appendDateTime(buf []byte, t time.Time) ([]byte, error) { + year, month, day := t.Date() + hour, min, sec := t.Clock() + nsec := t.Nanosecond() + + if year < 1 || year > 9999 { + return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap + } + year100 := year / 100 + year1 := year % 100 + + var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape + localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1] + localBuf[4] = '-' + localBuf[5], localBuf[6] = digits10[month], digits01[month] + localBuf[7] = '-' + localBuf[8], localBuf[9] = digits10[day], digits01[day] + + if hour == 0 && min == 0 && sec == 0 && nsec == 0 { + return append(buf, localBuf[:10]...), nil + } + + localBuf[10] = ' ' + localBuf[11], localBuf[12] = digits10[hour], digits01[hour] + localBuf[13] = ':' + localBuf[14], localBuf[15] = digits10[min], digits01[min] + localBuf[16] = ':' + localBuf[17], localBuf[18] = digits10[sec], digits01[sec] + + if nsec == 0 { + return append(buf, localBuf[:19]...), nil + } + nsec100000000 := nsec / 100000000 + nsec1000000 := (nsec / 1000000) % 100 + nsec10000 := (nsec / 10000) % 100 + nsec100 := (nsec / 100) % 100 + nsec1 := nsec % 100 + localBuf[19] = '.' + + // milli second + localBuf[20], localBuf[21], localBuf[22] = + digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] + // micro second + localBuf[23], localBuf[24], localBuf[25] = + digits10[nsec10000], digits01[nsec10000], digits10[nsec100] + // nano second + localBuf[26], localBuf[27], localBuf[28] = + digits01[nsec100], digits10[nsec1], digits01[nsec1] + + // trim trailing zeros + n := len(localBuf) + for n > 0 && localBuf[n-1] == '0' { + n-- + } + + return append(buf, localBuf[:n]...), nil +} + // zeroDateTime is used in formatBinaryDateTime to avoid an allocation // if the DATE or DATETIME has the zero value. // It must never be changed. @@ -683,6 +796,12 @@ type noCopy struct{} // Lock is a no-op used by -copylocks checker from `go vet`. func (*noCopy) Lock() {} +// Unlock is a no-op used by -copylocks checker from `go vet`. +// noCopy should implement sync.Locker from Go 1.11 +// https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc +// https://github.com/golang/go/issues/26165 +func (*noCopy) Unlock() {} + // atomicBool is a wrapper around uint32 for usage as a boolean value with // atomic access. type atomicBool struct { @@ -690,7 +809,7 @@ type atomicBool struct { value uint32 } -// IsSet returns wether the current boolean value is true +// IsSet returns whether the current boolean value is true func (ab *atomicBool) IsSet() bool { return atomic.LoadUint32(&ab.value) > 0 } @@ -704,7 +823,7 @@ func (ab *atomicBool) Set(value bool) { } } -// TrySet sets the value of the bool and returns wether the value changed +// TrySet sets the value of the bool and returns whether the value changed func (ab *atomicBool) TrySet(value bool) bool { if value { return atomic.SwapUint32(&ab.value, 1) == 0 diff --git a/utils_test.go b/utils_test.go index 8951a7a81..b0069251e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -17,46 +17,6 @@ import ( "time" ) -func TestScanNullTime(t *testing.T) { - var scanTests = []struct { - in interface{} - error bool - valid bool - time time.Time - }{ - {tDate, false, true, tDate}, - {sDate, false, true, tDate}, - {[]byte(sDate), false, true, tDate}, - {tDateTime, false, true, tDateTime}, - {sDateTime, false, true, tDateTime}, - {[]byte(sDateTime), false, true, tDateTime}, - {tDate0, false, true, tDate0}, - {sDate0, false, true, tDate0}, - {[]byte(sDate0), false, true, tDate0}, - {sDateTime0, false, true, tDate0}, - {[]byte(sDateTime0), false, true, tDate0}, - {"", true, false, tDate0}, - {"1234", true, false, tDate0}, - {0, true, false, tDate0}, - } - - var nt = NullTime{} - var err error - - for _, tst := range scanTests { - err = nt.Scan(tst.in) - if (err != nil) != tst.error { - t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil)) - } - if nt.Valid != tst.valid { - t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid) - } - if nt.Time != tst.time { - t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time) - } - } -} - func TestLengthEncodedInteger(t *testing.T) { var integerTests = []struct { num uint64 @@ -268,7 +228,9 @@ func TestAtomicBool(t *testing.T) { t.Fatal("Expected value to be false") } - ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯ + // we've "tested" them ¯\_(ツ)_/¯ + ab._noCopy.Lock() + defer ab._noCopy.Unlock() } func TestAtomicError(t *testing.T) { @@ -332,3 +294,217 @@ func TestIsolationLevelMapping(t *testing.T) { t.Fatalf("Expected error to be %q, got %q", expectedErr, err) } } + +func TestAppendDateTime(t *testing.T) { + tests := []struct { + t time.Time + str string + }{ + { + t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC), + str: "1234-05-06", + }, + { + t: time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC), + str: "4567-12-31 12:00:00", + }, + { + t: time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC), + str: "2020-05-30 12:34:00", + }, + { + t: time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC), + str: "2020-05-30 12:34:56", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC), + str: "2020-05-30 22:33:44.123", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC), + str: "2020-05-30 22:33:44.123456", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC), + str: "2020-05-30 22:33:44.123456789", + }, + { + t: time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC), + str: "9999-12-31 23:59:59.999999999", + }, + { + t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + str: "0001-01-01", + }, + } + for _, v := range tests { + buf := make([]byte, 0, 32) + buf, _ = appendDateTime(buf, v.t) + if str := string(buf); str != v.str { + t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str) + } + } + + // year out of range + { + v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } + { + v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } +} + +func TestParseDateTime(t *testing.T) { + cases := []struct { + name string + str string + }{ + { + name: "parse date", + str: "2020-05-13", + }, + { + name: "parse null date", + str: sDate0, + }, + { + name: "parse datetime", + str: "2020-05-13 21:30:45", + }, + { + name: "parse null datetime", + str: sDateTime0, + }, + { + name: "parse datetime nanosec 1-digit", + str: "2020-05-25 23:22:01.1", + }, + { + name: "parse datetime nanosec 2-digits", + str: "2020-05-25 23:22:01.15", + }, + { + name: "parse datetime nanosec 3-digits", + str: "2020-05-25 23:22:01.159", + }, + { + name: "parse datetime nanosec 4-digits", + str: "2020-05-25 23:22:01.1594", + }, + { + name: "parse datetime nanosec 5-digits", + str: "2020-05-25 23:22:01.15949", + }, + { + name: "parse datetime nanosec 6-digits", + str: "2020-05-25 23:22:01.159491", + }, + } + + for _, loc := range []*time.Location{ + time.UTC, + time.FixedZone("test", 8*60*60), + } { + for _, cc := range cases { + t.Run(cc.name+"-"+loc.String(), func(t *testing.T) { + var want time.Time + if cc.str != sDate0 && cc.str != sDateTime0 { + var err error + want, err = time.ParseInLocation(timeFormat[:len(cc.str)], cc.str, loc) + if err != nil { + t.Fatal(err) + } + } + got, err := parseDateTime([]byte(cc.str), loc) + if err != nil { + t.Fatal(err) + } + + if !want.Equal(got) { + t.Fatalf("want: %v, but got %v", want, got) + } + }) + } + } +} + +func TestParseDateTimeFail(t *testing.T) { + cases := []struct { + name string + str string + wantErr string + }{ + { + name: "parse invalid time", + str: "hello", + wantErr: "invalid time bytes: hello", + }, + { + name: "parse year", + str: "000!-00-00 00:00:00.000000", + wantErr: "not [0-9]", + }, + { + name: "parse month", + str: "0000-!0-00 00:00:00.000000", + wantErr: "not [0-9]", + }, + { + name: `parse "-" after parsed year`, + str: "0000:00-00 00:00:00.000000", + wantErr: "bad value for field: `:`", + }, + { + name: `parse "-" after parsed month`, + str: "0000-00:00 00:00:00.000000", + wantErr: "bad value for field: `:`", + }, + { + name: `parse " " after parsed date`, + str: "0000-00-00+00:00:00.000000", + wantErr: "bad value for field: `+`", + }, + { + name: `parse ":" after parsed date`, + str: "0000-00-00 00-00:00.000000", + wantErr: "bad value for field: `-`", + }, + { + name: `parse ":" after parsed hour`, + str: "0000-00-00 00:00-00.000000", + wantErr: "bad value for field: `-`", + }, + { + name: `parse "." after parsed sec`, + str: "0000-00-00 00:00:00?000000", + wantErr: "bad value for field: `?`", + }, + } + + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + got, err := parseDateTime([]byte(cc.str), time.UTC) + if err == nil { + t.Fatal("want error") + } + if cc.wantErr != err.Error() { + t.Fatalf("want `%s`, but got `%s`", cc.wantErr, err) + } + if !got.IsZero() { + t.Fatal("want zero time") + } + }) + } +}