diff --git a/connection.go b/connection.go index d1d8b29fe..90aec6439 100644 --- a/connection.go +++ b/connection.go @@ -245,44 +245,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { - v := v.In(mc.cfg.Loc) - v = v.Add(time.Nanosecond * 500) // To round under microsecond - year := v.Year() - year100 := year / 100 - year1 := year % 100 - month := v.Month() - day := v.Day() - hour := v.Hour() - minute := v.Minute() - second := v.Second() - micro := v.Nanosecond() / 1000 - - buf = append(buf, []byte{ - '\'', - digits10[year100], digits01[year100], - digits10[year1], digits01[year1], - '-', - digits10[month], digits01[month], - '-', - digits10[day], digits01[day], - ' ', - digits10[hour], digits01[hour], - ':', - digits10[minute], digits01[minute], - ':', - digits10[second], digits01[second], - }...) - - if micro != 0 { - micro10000 := micro / 10000 - micro100 := micro / 100 % 100 - micro1 := micro % 100 - buf = append(buf, []byte{ - '.', - digits10[micro10000], digits01[micro10000], - digits10[micro100], digits01[micro100], - digits10[micro1], digits01[micro1], - }...) + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + if err != nil { + return "", err } buf = append(buf, '\'') } diff --git a/packets.go b/packets.go index 8e2f5e76f..6664e5ae5 100644 --- a/packets.go +++ b/packets.go @@ -1116,7 +1116,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) + b, err = appendDateTime(b, v.In(mc.cfg.Loc)) + if err != nil { + return err + } } paramValues = appendLengthEncodedInteger(paramValues, diff --git a/utils.go b/utils.go index 9dd3679c6..b0c6e9ca3 100644 --- a/utils.go +++ b/utils.go @@ -276,6 +276,55 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } +func appendDateTime(buf []byte, t time.Time) ([]byte, error) { + nsec := t.Nanosecond() + // to round under microsecond + if nsec%1000 >= 500 { // save half of time.Time.Add calls + t = t.Add(500 * time.Nanosecond) + nsec = t.Nanosecond() + } + year, month, day := t.Date() + hour, min, sec := t.Clock() + micro := nsec / 1000 + + if year < 1 || year > 9999 { + return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap + } + year100 := year / 100 + year1 := year % 100 + + var localBuf [26]byte // does not escape + localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1] + localBuf[4] = '-' + localBuf[5], localBuf[6] = digits10[month], digits01[month] + localBuf[7] = '-' + localBuf[8], localBuf[9] = digits10[day], digits01[day] + + if hour == 0 && min == 0 && sec == 0 && micro == 0 { + return append(buf, localBuf[:10]...), nil + } + + localBuf[10] = ' ' + localBuf[11], localBuf[12] = digits10[hour], digits01[hour] + localBuf[13] = ':' + localBuf[14], localBuf[15] = digits10[min], digits01[min] + localBuf[16] = ':' + localBuf[17], localBuf[18] = digits10[sec], digits01[sec] + + if micro == 0 { + return append(buf, localBuf[:19]...), nil + } + + micro10000 := micro / 10000 + micro100 := (micro / 100) % 100 + micro1 := micro % 100 + localBuf[19] = '.' + localBuf[20], localBuf[21], localBuf[22], localBuf[23], localBuf[24], localBuf[25] = + digits10[micro10000], digits01[micro10000], digits10[micro100], digits01[micro100], digits10[micro1], digits01[micro1] + + return append(buf, localBuf[:]...), nil +} + // zeroDateTime is used in formatBinaryDateTime to avoid an allocation // if the DATE or DATETIME has the zero value. // It must never be changed. diff --git a/utils_test.go b/utils_test.go index 114f4b3da..e3619e7a7 100644 --- a/utils_test.go +++ b/utils_test.go @@ -293,6 +293,78 @@ func TestIsolationLevelMapping(t *testing.T) { } } +func TestAppendDateTime(t *testing.T) { + tests := []struct { + t time.Time + str string + }{ + { + t: time.Date(2020, 05, 30, 0, 0, 0, 0, time.UTC), + str: "2020-05-30", + }, + { + t: time.Date(2020, 05, 30, 22, 0, 0, 0, time.UTC), + str: "2020-05-30 22:00:00", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 0, 0, time.UTC), + str: "2020-05-30 22:33:00", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 44, 0, time.UTC), + str: "2020-05-30 22:33:44", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 44, 550000000, time.UTC), + str: "2020-05-30 22:33:44.550000", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 44, 550000499, time.UTC), + str: "2020-05-30 22:33:44.550000", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 44, 550000500, time.UTC), + str: "2020-05-30 22:33:44.550001", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 44, 550000567, time.UTC), + str: "2020-05-30 22:33:44.550001", + }, + { + t: time.Date(2020, 05, 30, 22, 33, 44, 999999567, time.UTC), + str: "2020-05-30 22:33:45", + }, + } + for _, v := range tests { + buf := make([]byte, 0, 32) + buf, _ = appendDateTime(buf, v.t) + if str := string(buf); str != v.str { + t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str) + return + } + } + + // year out of range + { + v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } + { + v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } +} + func TestParseDateTime(t *testing.T) { cases := []struct { name string