Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

utils: parse using byteslice in parseDateTime #1113

Merged
merged 11 commits into from May 31, 2020
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -55,6 +55,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
Expand Down
4 changes: 2 additions & 2 deletions nulltime.go
Expand Up @@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Time, err = parseDateTime([]byte(v), time.UTC)
nt.Valid = (err == nil)
return
}
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Expand Up @@ -778,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
dest[i].([]byte),
mc.cfg.Loc,
)
if err == nil {
Expand Down
141 changes: 131 additions & 10 deletions utils.go
Expand Up @@ -9,6 +9,7 @@
package mysql

import (
"bytes"
"crypto/tls"
"database/sql"
"database/sql/driver"
Expand Down Expand Up @@ -106,21 +107,141 @@ func readBool(input string) (value bool, valid bool) {
* Time related utils *
******************************************************************************/

func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
var (
nullTimeBaseStr = "0000-00-00 00:00:00.000000"
nullTimeBaseByte = []byte(nullTimeBaseStr)
methane marked this conversation as resolved.
Show resolved Hide resolved
)

func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
switch len(b) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
if bytes.Compare(b, nullTimeBaseByte[:len(b)]) == 0 {
return time.Time{}, nil
}

year, err := parseByteYear(b)
if err != nil {
return time.Time{}, err
}
if year <= 0 {
year = 1
}
if loc == time.UTC {
return time.Parse(timeFormat[:len(str)], str)

if b[4] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
}

m, err := parseByte2Digits(b[5], b[6])
if err != nil {
return time.Time{}, err
}
if m <= 0 {
m = 1
}
month := time.Month(m)

if b[7] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
}

day, err := parseByte2Digits(b[8], b[9])
if err != nil {
return time.Time{}, err
}
if day <= 0 {
day = 1
}
if len(b) == 10 {
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
}

if b[10] != ' ' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
}

hour, err := parseByte2Digits(b[11], b[12])
if err != nil {
return time.Time{}, err
}
if b[13] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
}

min, err := parseByte2Digits(b[14], b[15])
if err != nil {
return time.Time{}, err
}
return time.ParseInLocation(timeFormat[:len(str)], str, loc)
if b[16] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
}

sec, err := parseByte2Digits(b[17], b[18])
if err != nil {
return time.Time{}, err
}
if len(b) == 19 {
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
}

if b[19] != '.' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
}
nsec, err := parseByteNanoSec(b)
Code-Hex marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return time.Time{}, err
}
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
default:
err = fmt.Errorf("invalid time string: %s", str)
return
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
}
}

func parseByteYear(b []byte) (int, error) {
year, n := 0, 1000
for i := 0; i < 4; i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
year += v * n
n = n / 10
}
return year, nil
}

func parseByte2Digits(b1, b2 byte) (int, error) {
d2, err := bToi(b1)
if err != nil {
return 0, err
}
d1, err := bToi(b2)
if err != nil {
return 0, err
}
return d2*10 + d1, nil
Code-Hex marked this conversation as resolved.
Show resolved Hide resolved
}

func parseByteNanoSec(b []byte) (int, error) {
l := len(b)
ns, digit := 0, 100000 // max is 6-digits
for i := 20; i < l; i++ {
Code-Hex marked this conversation as resolved.
Show resolved Hide resolved
v, err := bToi(b[i])
if err != nil {
return 0, err
}
ns += v * digit
digit /= 10
}
// nanoseconds has 10-digits. (needs to scale digits)
// 10 - 6 = 4, so we have to multiple 1000.
return ns * 1000, nil
}

func bToi(b byte) (int, error) {
if b < '0' || b > '9' {
return 0, errors.New("not [0-9]")
}
return int(b - '0'), nil
}

func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
Expand Down