Skip to content

Commit

Permalink
return sql.NullTime if it available (go-sql-driver#1145)
Browse files Browse the repository at this point in the history
* return sql.NullTime if it available

* NullTime should be used with parseTime=true option
  • Loading branch information
shogo82148 authored and tz70s committed Sep 5, 2020
1 parent aa1902c commit 8650021
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 115 deletions.
222 changes: 108 additions & 114 deletions driver_test.go
Expand Up @@ -2758,13 +2758,13 @@ func TestRowsColumnTypes(t *testing.T) {
nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false}
nf0 := sql.NullFloat64{Float64: 0.0, Valid: true}
nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true}
nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
ndNULL := NullTime{Time: time.Time{}, Valid: false}
nt0 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
nt1 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
nt2 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
nt6 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
nd1 := nullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
nd2 := nullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
ndNULL := nullTime{Time: time.Time{}, Valid: false}
rbNULL := sql.RawBytes(nil)
rb0 := sql.RawBytes("0")
rb42 := sql.RawBytes("42")
Expand Down Expand Up @@ -2844,131 +2844,125 @@ func TestRowsColumnTypes(t *testing.T) {
values2 = values2[:len(values2)-2]
values3 = values3[:len(values3)-2]

dsns := []string{
dsn + "&parseTime=true",
dsn + "&parseTime=false",
}
for _, testdsn := range dsns {
runTests(t, testdsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (" + schema + ")")
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")
runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (" + schema + ")")
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")

rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query: %v", err)
}
rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query: %v", err)
}

tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}
tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}

if len(tt) != len(columns) {
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
}
if len(tt) != len(columns) {
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
}

types := make([]reflect.Type, len(tt))
for i, tp := range tt {
column := columns[i]
types := make([]reflect.Type, len(tt))
for i, tp := range tt {
column := columns[i]

// Name
name := tp.Name()
if name != column.name {
t.Errorf("column name mismatch %s != %s", name, column.name)
continue
}
// Name
name := tp.Name()
if name != column.name {
t.Errorf("column name mismatch %s != %s", name, column.name)
continue
}

// DatabaseTypeName
databaseTypeName := tp.DatabaseTypeName()
if databaseTypeName != column.databaseTypeName {
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
continue
}
// DatabaseTypeName
databaseTypeName := tp.DatabaseTypeName()
if databaseTypeName != column.databaseTypeName {
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
continue
}

// ScanType
scanType := tp.ScanType()
if scanType != column.scanType {
if scanType == nil {
t.Errorf("scantype is null for column %q", name)
} else {
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
}
continue
// ScanType
scanType := tp.ScanType()
if scanType != column.scanType {
if scanType == nil {
t.Errorf("scantype is null for column %q", name)
} else {
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
}
types[i] = scanType

// Nullable
nullable, ok := tp.Nullable()
continue
}
types[i] = scanType

// Nullable
nullable, ok := tp.Nullable()
if !ok {
t.Errorf("nullable not ok %q", name)
continue
}
if nullable != column.nullable {
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
}

// Length
// length, ok := tp.Length()
// if length != column.length {
// if !ok {
// t.Errorf("length not ok for column %q", name)
// } else {
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
// }
// continue
// }

// Precision and Scale
precision, scale, ok := tp.DecimalSize()
if precision != column.precision {
if !ok {
t.Errorf("nullable not ok %q", name)
continue
}
if nullable != column.nullable {
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
t.Errorf("precision not ok for column %q", name)
} else {
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
}

// Length
// length, ok := tp.Length()
// if length != column.length {
// if !ok {
// t.Errorf("length not ok for column %q", name)
// } else {
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
// }
// continue
// }

// Precision and Scale
precision, scale, ok := tp.DecimalSize()
if precision != column.precision {
if !ok {
t.Errorf("precision not ok for column %q", name)
} else {
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
}
continue
}
if scale != column.scale {
if !ok {
t.Errorf("scale not ok for column %q", name)
} else {
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
}
continue
continue
}
if scale != column.scale {
if !ok {
t.Errorf("scale not ok for column %q", name)
} else {
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
}
continue
}
}

values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
}
i := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
i := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
for j := range values {
value := reflect.ValueOf(values[j]).Elem().Interface()
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
if columns[j].scanType == scanTypeRawBytes {
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
} else {
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
}
for j := range values {
value := reflect.ValueOf(values[j]).Elem().Interface()
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
if columns[j].scanType == scanTypeRawBytes {
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
} else {
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
}
}
i++
}
if i != 3 {
t.Errorf("expected 3 rows, got %d", i)
}
i++
}
if i != 3 {
t.Errorf("expected 3 rows, got %d", i)
}

if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
})
}
if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
})
}

func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion fields.go
Expand Up @@ -106,7 +106,7 @@ var (
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(NullTime{})
scanTypeNullTime = reflect.TypeOf(nullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
Expand Down
5 changes: 5 additions & 0 deletions nulltime_go113.go
Expand Up @@ -33,3 +33,8 @@ import (
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
// Use sql.NullTime instead.
type NullTime sql.NullTime

// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = sql.NullTime // sql.NullTime is available
5 changes: 5 additions & 0 deletions nulltime_legacy.go
Expand Up @@ -32,3 +32,8 @@ type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = NullTime // sql.NullTime is not available

0 comments on commit 8650021

Please sign in to comment.