Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

utils: parse using byteslice in parseDateTime #1113

Merged
merged 11 commits into from May 31, 2020
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -55,6 +55,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
Expand Down
4 changes: 2 additions & 2 deletions nulltime.go
Expand Up @@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Time, err = parseDateTime([]byte(v), time.UTC)
nt.Valid = (err == nil)
return
}
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Expand Up @@ -778,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
dest[i].([]byte),
mc.cfg.Loc,
)
if err == nil {
Expand Down
136 changes: 126 additions & 10 deletions utils.go
Expand Up @@ -106,21 +106,137 @@ func readBool(input string) (value bool, valid bool) {
* Time related utils *
******************************************************************************/

func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
const base = "0000-00-00 00:00:00.000000"
switch len(b) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
if string(b) == base[:len(b)] {
return time.Time{}, nil
}
if loc == time.UTC {
return time.Parse(timeFormat[:len(str)], str)

year, err := parseByteYear(b)
if err != nil {
return time.Time{}, err
}
if year <= 0 {
year = 1
}

if b[4] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
}

m, err := parseByte2Digits(b[5], b[6])
if err != nil {
return time.Time{}, err
}
if m <= 0 {
m = 1
}
month := time.Month(m)

if b[7] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
}

day, err := parseByte2Digits(b[8], b[9])
if err != nil {
return time.Time{}, err
}
if day <= 0 {
day = 1
}
if len(b) == 10 {
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
}

if b[10] != ' ' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
}

hour, err := parseByte2Digits(b[11], b[12])
if err != nil {
return time.Time{}, err
}
if b[13] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
}

min, err := parseByte2Digits(b[14], b[15])
if err != nil {
return time.Time{}, err
}
if b[16] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
}
return time.ParseInLocation(timeFormat[:len(str)], str, loc)

sec, err := parseByte2Digits(b[17], b[18])
if err != nil {
return time.Time{}, err
}
if len(b) == 19 {
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
}

if b[19] != '.' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
}
nsec, err := parseByteNanoSec(b)
Code-Hex marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return time.Time{}, err
}
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
default:
err = fmt.Errorf("invalid time string: %s", str)
return
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
}
}

func parseByteYear(b []byte) (int, error) {
year, n := 0, 1000
for i := 0; i < 4; i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
year += v * n
n = n / 10
}
return year, nil
}

func parseByte2Digits(b1, b2 byte) (int, error) {
d2, err := bToi(b1)
if err != nil {
return 0, err
}
d1, err := bToi(b2)
if err != nil {
return 0, err
}
return d2*10 + d1, nil
Code-Hex marked this conversation as resolved.
Show resolved Hide resolved
}

func parseByteNanoSec(b []byte) (int, error) {
l := len(b)
ns, digit := 0, 100000 // max is 6-digits
for i := 20; i < l; i++ {
Code-Hex marked this conversation as resolved.
Show resolved Hide resolved
v, err := bToi(b[i])
if err != nil {
return 0, err
}
ns += v * digit
digit /= 10
}
// nanoseconds has 10-digits. (needs to scale digits)
// 10 - 6 = 4, so we have to multiple 1000.
return ns * 1000, nil
}

func bToi(b byte) (int, error) {
if b < '0' || b > '9' {
return 0, errors.New("not [0-9]")
}
return int(b - '0'), nil
}

func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
Expand Down
195 changes: 168 additions & 27 deletions utils_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"database/sql"
"database/sql/driver"
"encoding/binary"
"fmt"
"testing"
"time"
)
Expand Down Expand Up @@ -293,44 +294,184 @@ func TestIsolationLevelMapping(t *testing.T) {
}
}

func TestParseDateTime(t *testing.T) {
// UTC loc
{
str := "2020-05-13 21:30:45"
t1, err := parseDateTime(str, time.UTC)
if err != nil {
t.Error(err)
func deprecatedParseDateTime(str string, loc *time.Location) (t time.Time, err error) {
const base = "0000-00-00 00:00:00.000000"
switch len(str) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
}
t2 := time.Date(2020, 5, 13,
21, 30, 45, 0, time.UTC)
if !t1.Equal(t2) {
t.Errorf("want equal, have: %v, want: %v", t1, t2)
return
if loc == time.UTC {
return time.Parse(timeFormat[:len(str)], str)
}
return time.ParseInLocation(timeFormat[:len(str)], str, loc)
default:
err = fmt.Errorf("invalid time string: %s", str)
return
}
// non-UTC loc
{
str := "2020-05-13 21:30:45"
loc := time.FixedZone("test", 8*60*60)
t1, err := parseDateTime(str, loc)
if err != nil {
t.Error(err)
return
}
t2 := time.Date(2020, 5, 13,
21, 30, 45, 0, loc)
if !t1.Equal(t2) {
t.Errorf("want equal, have: %v, want: %v", t1, t2)
return
}

func TestParseDateTime(t *testing.T) {
cases := []struct {
name string
str string
}{
{
name: "parse date",
str: "2020-05-13",
},
{
name: "parse null date",
str: sDate0,
},
{
name: "parse datetime",
str: "2020-05-13 21:30:45",
},
{
name: "parse null datetime",
str: sDateTime0,
},
{
name: "parse datetime nanosec 1-digit",
str: "2020-05-25 23:22:01.1",
},
{
name: "parse datetime nanosec 2-digits",
str: "2020-05-25 23:22:01.15",
},
{
name: "parse datetime nanosec 3-digits",
str: "2020-05-25 23:22:01.159",
},
{
name: "parse datetime nanosec 4-digits",
str: "2020-05-25 23:22:01.1594",
},
{
name: "parse datetime nanosec 5-digits",
str: "2020-05-25 23:22:01.15949",
},
{
name: "parse datetime nanosec 6-digits",
str: "2020-05-25 23:22:01.159491",
},
}

for _, loc := range []*time.Location{
time.UTC,
time.FixedZone("test", 8*60*60),
} {
for _, cc := range cases {
t.Run(cc.name+"-"+loc.String(), func(t *testing.T) {
want, err := deprecatedParseDateTime(cc.str, loc)
if err != nil {
t.Fatal(err)
}
got, err := parseDateTime([]byte(cc.str), loc)
if err != nil {
t.Fatal(err)
}

if !want.Equal(got) {
t.Fatalf("want: %v, but got %v", want, got)
}
})
}
}
}

func TestParseDateTimeFail(t *testing.T) {
cases := []struct {
name string
str string
wantErr string
}{
{
name: "parse invalid time",
str: "hello",
wantErr: "invalid time bytes: hello",
},
{
name: "parse year",
str: "000!-00-00 00:00:00.000000",
wantErr: "not [0-9]",
},
{
name: "parse month",
str: "0000-!0-00 00:00:00.000000",
wantErr: "not [0-9]",
},
{
name: `parse "-" after parsed year`,
str: "0000:00-00 00:00:00.000000",
wantErr: "bad value for field: `:`",
},
{
name: `parse "-" after parsed month`,
str: "0000-00:00 00:00:00.000000",
wantErr: "bad value for field: `:`",
},
{
name: `parse " " after parsed date`,
str: "0000-00-00+00:00:00.000000",
wantErr: "bad value for field: `+`",
},
{
name: `parse ":" after parsed date`,
str: "0000-00-00 00-00:00.000000",
wantErr: "bad value for field: `-`",
},
{
name: `parse ":" after parsed hour`,
str: "0000-00-00 00:00-00.000000",
wantErr: "bad value for field: `-`",
},
{
name: `parse "." after parsed sec`,
str: "0000-00-00 00:00:00?000000",
wantErr: "bad value for field: `?`",
},
}

for _, cc := range cases {
t.Run(cc.name, func(t *testing.T) {
got, err := parseDateTime([]byte(cc.str), time.UTC)
if err == nil {
t.Fatal("want error")
}
if cc.wantErr != err.Error() {
t.Fatalf("want `%s`, but got `%s`", cc.wantErr, err)
}
if !got.IsZero() {
t.Fatal("want zero time")
}
})
}
}

func BenchmarkParseDateTime(b *testing.B) {
str := "2020-05-13 21:30:45"
loc := time.FixedZone("test", 8*60*60)
for i := 0; i < b.N; i++ {
_, _ = parseDateTime(str, loc)
_, _ = deprecatedParseDateTime(str, loc)
}
}

func BenchmarkParseByteDateTime(b *testing.B) {
bStr := []byte("2020-05-25 23:22:01.159491")
loc := time.FixedZone("test", 8*60*60)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = parseDateTime(bStr, loc)
}
}

func BenchmarkParseByteDateTimeStringCast(b *testing.B) {
bStr := []byte("2020-05-25 23:22:01.159491")
loc := time.FixedZone("test", 8*60*60)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = deprecatedParseDateTime(string(bStr), loc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please post the result of this benchmark and remove deprecatedParseDateTime.
We do not keep old implementations only for benchmark.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed 2c6aa27

}
}