Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

utils: parse using byteslice in parseDateTime #1113

Merged
merged 11 commits into from May 31, 2020
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -55,6 +55,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
Expand Down
4 changes: 2 additions & 2 deletions nulltime.go
Expand Up @@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Time, err = parseDateTime([]byte(v), time.UTC)
nt.Valid = (err == nil)
return
}
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Expand Up @@ -778,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
dest[i].([]byte),
mc.cfg.Loc,
)
if err == nil {
Expand Down
135 changes: 125 additions & 10 deletions utils.go
Expand Up @@ -106,21 +106,136 @@ func readBool(input string) (value bool, valid bool) {
* Time related utils *
******************************************************************************/

func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
const base = "0000-00-00 00:00:00.000000"
switch len(b) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
if string(b) == base[:len(b)] {
return time.Time{}, nil
}
if loc == time.UTC {
return time.Parse(timeFormat[:len(str)], str)

year, err := parseByteYear(b)
if err != nil {
return time.Time{}, err
}
if year <= 0 {
year = 1
}

if b[4] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
}

m, err := parseByte2Digits(b[5], b[6])
if err != nil {
return time.Time{}, err
}
if m <= 0 {
m = 1
}
month := time.Month(m)

if b[7] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
}

day, err := parseByte2Digits(b[8], b[9])
if err != nil {
return time.Time{}, err
}
if day <= 0 {
day = 1
}
if len(b) == 10 {
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
}

if b[10] != ' ' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
}

hour, err := parseByte2Digits(b[11], b[12])
if err != nil {
return time.Time{}, err
}
if b[13] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
}

min, err := parseByte2Digits(b[14], b[15])
if err != nil {
return time.Time{}, err
}
if b[16] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
}
return time.ParseInLocation(timeFormat[:len(str)], str, loc)

sec, err := parseByte2Digits(b[17], b[18])
if err != nil {
return time.Time{}, err
}
if len(b) == 19 {
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
}

if b[19] != '.' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
}
nsec, err := parseByteNanoSec(b[20:])
if err != nil {
return time.Time{}, err
}
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
default:
err = fmt.Errorf("invalid time string: %s", str)
return
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
}
}

func parseByteYear(b []byte) (int, error) {
year, n := 0, 1000
for i := 0; i < 4; i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
year += v * n
n = n / 10
}
return year, nil
}

func parseByte2Digits(b1, b2 byte) (int, error) {
d1, err := bToi(b1)
if err != nil {
return 0, err
}
d2, err := bToi(b2)
if err != nil {
return 0, err
}
return d1*10 + d2, nil
}

func parseByteNanoSec(b []byte) (int, error) {
ns, digit := 0, 100000 // max is 6-digits
for i := 0; i < len(b); i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
ns += v * digit
digit /= 10
}
// nanoseconds has 10-digits. (needs to scale digits)
// 10 - 6 = 4, so we have to multiple 1000.
return ns * 1000, nil
}

func bToi(b byte) (int, error) {
if b < '0' || b > '9' {
return 0, errors.New("not [0-9]")
}
return int(b - '0'), nil
}

func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
Expand Down
167 changes: 134 additions & 33 deletions utils_test.go
Expand Up @@ -294,43 +294,144 @@ func TestIsolationLevelMapping(t *testing.T) {
}

func TestParseDateTime(t *testing.T) {
// UTC loc
{
str := "2020-05-13 21:30:45"
t1, err := parseDateTime(str, time.UTC)
if err != nil {
t.Error(err)
return
}
t2 := time.Date(2020, 5, 13,
21, 30, 45, 0, time.UTC)
if !t1.Equal(t2) {
t.Errorf("want equal, have: %v, want: %v", t1, t2)
return
}
cases := []struct {
name string
str string
}{
{
name: "parse date",
str: "2020-05-13",
},
{
name: "parse null date",
str: sDate0,
},
{
name: "parse datetime",
str: "2020-05-13 21:30:45",
},
{
name: "parse null datetime",
str: sDateTime0,
},
{
name: "parse datetime nanosec 1-digit",
str: "2020-05-25 23:22:01.1",
},
{
name: "parse datetime nanosec 2-digits",
str: "2020-05-25 23:22:01.15",
},
{
name: "parse datetime nanosec 3-digits",
str: "2020-05-25 23:22:01.159",
},
{
name: "parse datetime nanosec 4-digits",
str: "2020-05-25 23:22:01.1594",
},
{
name: "parse datetime nanosec 5-digits",
str: "2020-05-25 23:22:01.15949",
},
{
name: "parse datetime nanosec 6-digits",
str: "2020-05-25 23:22:01.159491",
},
}
// non-UTC loc
{
str := "2020-05-13 21:30:45"
loc := time.FixedZone("test", 8*60*60)
t1, err := parseDateTime(str, loc)
if err != nil {
t.Error(err)
return
}
t2 := time.Date(2020, 5, 13,
21, 30, 45, 0, loc)
if !t1.Equal(t2) {
t.Errorf("want equal, have: %v, want: %v", t1, t2)
return

for _, loc := range []*time.Location{
time.UTC,
time.FixedZone("test", 8*60*60),
} {
for _, cc := range cases {
t.Run(cc.name+"-"+loc.String(), func(t *testing.T) {
var want time.Time
if cc.str != sDate0 && cc.str != sDateTime0 {
var err error
want, err = time.ParseInLocation(timeFormat[:len(cc.str)], cc.str, loc)
if err != nil {
t.Fatal(err)
}
}
got, err := parseDateTime([]byte(cc.str), loc)
if err != nil {
t.Fatal(err)
}

if !want.Equal(got) {
t.Fatalf("want: %v, but got %v", want, got)
}
})
}
}
}

func BenchmarkParseDateTime(b *testing.B) {
str := "2020-05-13 21:30:45"
loc := time.FixedZone("test", 8*60*60)
for i := 0; i < b.N; i++ {
_, _ = parseDateTime(str, loc)
func TestParseDateTimeFail(t *testing.T) {
cases := []struct {
name string
str string
wantErr string
}{
{
name: "parse invalid time",
str: "hello",
wantErr: "invalid time bytes: hello",
},
{
name: "parse year",
str: "000!-00-00 00:00:00.000000",
wantErr: "not [0-9]",
},
{
name: "parse month",
str: "0000-!0-00 00:00:00.000000",
wantErr: "not [0-9]",
},
{
name: `parse "-" after parsed year`,
str: "0000:00-00 00:00:00.000000",
wantErr: "bad value for field: `:`",
},
{
name: `parse "-" after parsed month`,
str: "0000-00:00 00:00:00.000000",
wantErr: "bad value for field: `:`",
},
{
name: `parse " " after parsed date`,
str: "0000-00-00+00:00:00.000000",
wantErr: "bad value for field: `+`",
},
{
name: `parse ":" after parsed date`,
str: "0000-00-00 00-00:00.000000",
wantErr: "bad value for field: `-`",
},
{
name: `parse ":" after parsed hour`,
str: "0000-00-00 00:00-00.000000",
wantErr: "bad value for field: `-`",
},
{
name: `parse "." after parsed sec`,
str: "0000-00-00 00:00:00?000000",
wantErr: "bad value for field: `?`",
},
}

for _, cc := range cases {
t.Run(cc.name, func(t *testing.T) {
got, err := parseDateTime([]byte(cc.str), time.UTC)
if err == nil {
t.Fatal("want error")
}
if cc.wantErr != err.Error() {
t.Fatalf("want `%s`, but got `%s`", cc.wantErr, err)
}
if !got.IsZero() {
t.Fatal("want zero time")
}
})
}
}