Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return sql.NullTime if it available #1145

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
222 changes: 108 additions & 114 deletions driver_test.go
Expand Up @@ -2758,13 +2758,13 @@ func TestRowsColumnTypes(t *testing.T) {
nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false}
nf0 := sql.NullFloat64{Float64: 0.0, Valid: true}
nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true}
nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
ndNULL := NullTime{Time: time.Time{}, Valid: false}
nt0 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
nt1 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
nt2 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
nt6 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
nd1 := nullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
nd2 := nullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
ndNULL := nullTime{Time: time.Time{}, Valid: false}
rbNULL := sql.RawBytes(nil)
rb0 := sql.RawBytes("0")
rb42 := sql.RawBytes("42")
Expand Down Expand Up @@ -2844,131 +2844,125 @@ func TestRowsColumnTypes(t *testing.T) {
values2 = values2[:len(values2)-2]
values3 = values3[:len(values3)-2]

dsns := []string{
dsn + "&parseTime=true",
dsn + "&parseTime=false",
}
for _, testdsn := range dsns {
runTests(t, testdsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (" + schema + ")")
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")
runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (" + schema + ")")
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")

rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query: %v", err)
}
rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query: %v", err)
}

tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}
tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}

if len(tt) != len(columns) {
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
}
if len(tt) != len(columns) {
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
}

types := make([]reflect.Type, len(tt))
for i, tp := range tt {
column := columns[i]
types := make([]reflect.Type, len(tt))
for i, tp := range tt {
column := columns[i]

// Name
name := tp.Name()
if name != column.name {
t.Errorf("column name mismatch %s != %s", name, column.name)
continue
}
// Name
name := tp.Name()
if name != column.name {
t.Errorf("column name mismatch %s != %s", name, column.name)
continue
}

// DatabaseTypeName
databaseTypeName := tp.DatabaseTypeName()
if databaseTypeName != column.databaseTypeName {
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
continue
}
// DatabaseTypeName
databaseTypeName := tp.DatabaseTypeName()
if databaseTypeName != column.databaseTypeName {
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
continue
}

// ScanType
scanType := tp.ScanType()
if scanType != column.scanType {
if scanType == nil {
t.Errorf("scantype is null for column %q", name)
} else {
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
}
continue
// ScanType
scanType := tp.ScanType()
if scanType != column.scanType {
if scanType == nil {
t.Errorf("scantype is null for column %q", name)
} else {
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
}
types[i] = scanType

// Nullable
nullable, ok := tp.Nullable()
continue
}
types[i] = scanType

// Nullable
nullable, ok := tp.Nullable()
if !ok {
t.Errorf("nullable not ok %q", name)
continue
}
if nullable != column.nullable {
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
}

// Length
// length, ok := tp.Length()
// if length != column.length {
// if !ok {
// t.Errorf("length not ok for column %q", name)
// } else {
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
// }
// continue
// }

// Precision and Scale
precision, scale, ok := tp.DecimalSize()
if precision != column.precision {
if !ok {
t.Errorf("nullable not ok %q", name)
continue
}
if nullable != column.nullable {
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
t.Errorf("precision not ok for column %q", name)
} else {
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
}

// Length
// length, ok := tp.Length()
// if length != column.length {
// if !ok {
// t.Errorf("length not ok for column %q", name)
// } else {
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
// }
// continue
// }

// Precision and Scale
precision, scale, ok := tp.DecimalSize()
if precision != column.precision {
if !ok {
t.Errorf("precision not ok for column %q", name)
} else {
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
}
continue
}
if scale != column.scale {
if !ok {
t.Errorf("scale not ok for column %q", name)
} else {
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
}
continue
continue
}
if scale != column.scale {
if !ok {
t.Errorf("scale not ok for column %q", name)
} else {
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
}
continue
}
}

values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
}
i := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
i := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
for j := range values {
value := reflect.ValueOf(values[j]).Elem().Interface()
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
if columns[j].scanType == scanTypeRawBytes {
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
} else {
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
}
for j := range values {
value := reflect.ValueOf(values[j]).Elem().Interface()
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
if columns[j].scanType == scanTypeRawBytes {
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
} else {
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
}
}
i++
}
if i != 3 {
t.Errorf("expected 3 rows, got %d", i)
}
i++
}
if i != 3 {
t.Errorf("expected 3 rows, got %d", i)
}

if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
})
}
if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
})
}

func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion fields.go
Expand Up @@ -106,7 +106,7 @@ var (
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(NullTime{})
scanTypeNullTime = reflect.TypeOf(nullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
Expand Down
5 changes: 5 additions & 0 deletions nulltime_go113.go
Expand Up @@ -29,3 +29,8 @@ import (
//
// This NullTime implementation is not driver-specific
type NullTime sql.NullTime

// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = sql.NullTime // sql.NullTime is available
5 changes: 5 additions & 0 deletions nulltime_legacy.go
Expand Up @@ -32,3 +32,8 @@ type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = NullTime // sql.NullTime is not available