diff --git a/driver_test.go b/driver_test.go index 34b476ed3..aa55d2f55 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2758,13 +2758,13 @@ func TestRowsColumnTypes(t *testing.T) { nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} - nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} - nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} - nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} - nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} - nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} - nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} - ndNULL := NullTime{Time: time.Time{}, Valid: false} + nt0 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := nullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := nullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := nullTime{Time: time.Time{}, Valid: false} rbNULL := sql.RawBytes(nil) rb0 := sql.RawBytes("0") rb42 := sql.RawBytes("42") @@ -2844,131 +2844,125 @@ func TestRowsColumnTypes(t *testing.T) { values2 = values2[:len(values2)-2] values3 = values3[:len(values3)-2] - dsns := []string{ - dsn + "&parseTime=true", - dsn + "&parseTime=false", - } - for _, testdsn := range dsns { - runTests(t, testdsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (" + schema + ")") - dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") - rows, err := dbt.db.Query("SELECT * FROM test") - if err != nil { - t.Fatalf("Query: %v", err) - } + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } - tt, err := rows.ColumnTypes() - if err != nil { - t.Fatalf("ColumnTypes: %v", err) - } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } - if len(tt) != len(columns) { - t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) - } + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } - types := make([]reflect.Type, len(tt)) - for i, tp := range tt { - column := columns[i] + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] - // Name - name := tp.Name() - if name != column.name { - t.Errorf("column name mismatch %s != %s", name, column.name) - continue - } + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } - // DatabaseTypeName - databaseTypeName := tp.DatabaseTypeName() - if databaseTypeName != column.databaseTypeName { - t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) - continue - } + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } - // ScanType - scanType := tp.ScanType() - if scanType != column.scanType { - if scanType == nil { - t.Errorf("scantype is null for column %q", name) - } else { - t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) - } - continue + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) } - types[i] = scanType - - // Nullable - nullable, ok := tp.Nullable() + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { if !ok { - t.Errorf("nullable not ok %q", name) - continue - } - if nullable != column.nullable { - t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) } - - // Length - // length, ok := tp.Length() - // if length != column.length { - // if !ok { - // t.Errorf("length not ok for column %q", name) - // } else { - // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) - // } - // continue - // } - - // Precision and Scale - precision, scale, ok := tp.DecimalSize() - if precision != column.precision { - if !ok { - t.Errorf("precision not ok for column %q", name) - } else { - t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) - } - continue - } - if scale != column.scale { - if !ok { - t.Errorf("scale not ok for column %q", name) - } else { - t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) - } - continue + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) } + continue } + } - values := make([]interface{}, len(tt)) - for i := range values { - values[i] = reflect.New(types[i]).Interface() + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) } - i := 0 - for rows.Next() { - err = rows.Scan(values...) - if err != nil { - t.Fatalf("failed to scan values in %v", err) - } - for j := range values { - value := reflect.ValueOf(values[j]).Elem().Interface() - if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { - if columns[j].scanType == scanTypeRawBytes { - t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) - } else { - t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) - } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) } } - i++ - } - if i != 3 { - t.Errorf("expected 3 rows, got %d", i) } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } - if err := rows.Close(); err != nil { - t.Errorf("error closing rows: %s", err) - } - }) - } + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) } func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { diff --git a/fields.go b/fields.go index e1e2ece4b..ed6c7a37d 100644 --- a/fields.go +++ b/fields.go @@ -106,7 +106,7 @@ var ( scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeNullTime = reflect.TypeOf(nullTime{}) scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint32 = reflect.TypeOf(uint32(0)) diff --git a/nulltime_go113.go b/nulltime_go113.go index c392594dd..da360a459 100644 --- a/nulltime_go113.go +++ b/nulltime_go113.go @@ -29,3 +29,8 @@ import ( // // This NullTime implementation is not driver-specific type NullTime sql.NullTime + +// for internal use. +// the mysql package uses sql.NullTime if it is available. +// if not, the package uses mysql.NullTime. +type nullTime = sql.NullTime // sql.NullTime is available diff --git a/nulltime_legacy.go b/nulltime_legacy.go index 86d159d44..9f7ae27a8 100644 --- a/nulltime_legacy.go +++ b/nulltime_legacy.go @@ -32,3 +32,8 @@ type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } + +// for internal use. +// the mysql package uses sql.NullTime if it is available. +// if not, the package uses mysql.NullTime. +type nullTime = NullTime // sql.NullTime is not available