Skip to content

Commit

Permalink
feat(spanner): implement valuer and scanner interfaces (#4936)
Browse files Browse the repository at this point in the history
### Valuer / Scanner Interfaces
Adds implementations for the `driver.Valuer` and `sql.Scanner` interfaces for the Spanner Null* types. This makes it possible to use both the `spanner.Null*` types and the underlying native types in the Go sql driver.

That is, both the following will then be supported:

```go
var r spanner.NullNumeric
rows.Scan(&r)
```

AND

```go
var r big.Rat
rows.Scan(&r)
```

It is not possible to implement this directly in the Go sql driver, as these types are defined in the `spanner` package.

The interfaces are not implemented for the `spanner.NullJSON` type for two reasons:

1. `NullJSON` already has a field called `Value`, which makes it technically impossible to add a method called `Value()`.
2. The underlying value of `NullJSON` is of type `interface{}`, which means that it can be anything. This means that there is no relevant other type than `NullJSON` that a user can use when calling `sql.Row#Scan(dest ...interface{})` for a JSON column.

### Gorm Data Type
Adds default data type mappings for the `spanner.Null*` types. That is; `NullInt64` is for example mapped by default to an `INT64` column. This allows structs that use `spanner.Null*` types for its fields to be used directly in Gorm migrations without the need to annotate them with the data type they should have in the database.
This feature is implemented by adding the `func GormDataType() string` to each of the `spanner.Null*` types.
  • Loading branch information
olavloite committed Oct 19, 2021
1 parent 44bc953 commit 4537b45
Showing 1 changed file with 287 additions and 1 deletion.
288 changes: 287 additions & 1 deletion spanner/value.go
Expand Up @@ -18,6 +18,8 @@ package spanner

import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -210,6 +212,43 @@ func (n *NullInt64) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullInt64) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.Int64, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullInt64) Scan(value interface{}) error {
if value == nil {
n.Int64, n.Valid = 0, false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullInt64: %v", p)
case *int64:
n.Int64 = *p
case int64:
n.Int64 = p
case *NullInt64:
n.Int64 = p.Int64
n.Valid = p.Valid
case NullInt64:
n.Int64 = p.Int64
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullInt64) GormDataType() string {
return "INT64"
}

// NullString represents a Cloud Spanner STRING that may be NULL.
type NullString struct {
StringVal string // StringVal contains the value when it is non-NULL, and an empty string when NULL.
Expand Down Expand Up @@ -256,6 +295,43 @@ func (n *NullString) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullString) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.StringVal, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullString) Scan(value interface{}) error {
if value == nil {
n.StringVal, n.Valid = "", false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullString: %v", p)
case *string:
n.StringVal = *p
case string:
n.StringVal = p
case *NullString:
n.StringVal = p.StringVal
n.Valid = p.Valid
case NullString:
n.StringVal = p.StringVal
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullString) GormDataType() string {
return "STRING(MAX)"
}

// NullFloat64 represents a Cloud Spanner FLOAT64 that may be NULL.
type NullFloat64 struct {
Float64 float64 // Float64 contains the value when it is non-NULL, and zero when NULL.
Expand Down Expand Up @@ -302,6 +378,43 @@ func (n *NullFloat64) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullFloat64) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.Float64, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullFloat64) Scan(value interface{}) error {
if value == nil {
n.Float64, n.Valid = 0, false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullFloat64: %v", p)
case *float64:
n.Float64 = *p
case float64:
n.Float64 = p
case *NullFloat64:
n.Float64 = p.Float64
n.Valid = p.Valid
case NullFloat64:
n.Float64 = p.Float64
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullFloat64) GormDataType() string {
return "FLOAT64"
}

// NullBool represents a Cloud Spanner BOOL that may be NULL.
type NullBool struct {
Bool bool // Bool contains the value when it is non-NULL, and false when NULL.
Expand Down Expand Up @@ -348,6 +461,43 @@ func (n *NullBool) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullBool) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.Bool, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullBool) Scan(value interface{}) error {
if value == nil {
n.Bool, n.Valid = false, false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullBool: %v", p)
case *bool:
n.Bool = *p
case bool:
n.Bool = p
case *NullBool:
n.Bool = p.Bool
n.Valid = p.Valid
case NullBool:
n.Bool = p.Bool
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullBool) GormDataType() string {
return "BOOL"
}

// NullTime represents a Cloud Spanner TIMESTAMP that may be null.
type NullTime struct {
Time time.Time // Time contains the value when it is non-NULL, and a zero time.Time when NULL.
Expand Down Expand Up @@ -399,6 +549,43 @@ func (n *NullTime) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullTime) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.Time, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullTime) Scan(value interface{}) error {
if value == nil {
n.Time, n.Valid = time.Time{}, false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullTime: %v", p)
case *time.Time:
n.Time = *p
case time.Time:
n.Time = p
case *NullTime:
n.Time = p.Time
n.Valid = p.Valid
case NullTime:
n.Time = p.Time
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullTime) GormDataType() string {
return "TIMESTAMP"
}

// NullDate represents a Cloud Spanner DATE that may be null.
type NullDate struct {
Date civil.Date // Date contains the value when it is non-NULL, and a zero civil.Date when NULL.
Expand Down Expand Up @@ -450,6 +637,43 @@ func (n *NullDate) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullDate) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.Date, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullDate) Scan(value interface{}) error {
if value == nil {
n.Date, n.Valid = civil.Date{}, false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullDate: %v", p)
case *civil.Date:
n.Date = *p
case civil.Date:
n.Date = p
case *NullDate:
n.Date = p.Date
n.Valid = p.Valid
case NullDate:
n.Date = p.Date
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullDate) GormDataType() string {
return "DATE"
}

// NullNumeric represents a Cloud Spanner Numeric that may be NULL.
type NullNumeric struct {
Numeric big.Rat // Numeric contains the value when it is non-NULL, and a zero big.Rat when NULL.
Expand Down Expand Up @@ -501,10 +725,52 @@ func (n *NullNumeric) UnmarshalJSON(payload []byte) error {
return nil
}

// Value implements the driver.Valuer interface.
func (n NullNumeric) Value() (driver.Value, error) {
if n.IsNull() {
return nil, nil
}
return n.Numeric, nil
}

// Scan implements the sql.Scanner interface.
func (n *NullNumeric) Scan(value interface{}) error {
if value == nil {
n.Numeric, n.Valid = big.Rat{}, false
return nil
}
n.Valid = true
switch p := value.(type) {
default:
return spannerErrorf(codes.InvalidArgument, "invalid type for NullNumeric: %v", p)
case *big.Rat:
n.Numeric = *p
case big.Rat:
n.Numeric = p
case *NullNumeric:
n.Numeric = p.Numeric
n.Valid = p.Valid
case NullNumeric:
n.Numeric = p.Numeric
n.Valid = p.Valid
}
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullNumeric) GormDataType() string {
return "NUMERIC"
}

// NullJSON represents a Cloud Spanner JSON that may be NULL.
//
// This type must always be used when encoding values to a JSON column in Cloud
// Spanner.
//
// NullJSON does not implement the driver.Valuer and sql.Scanner interfaces, as
// the underlying value can be anything. This means that the type NullJSON must
// also be used when calling sql.Row#Scan(dest ...interface{}) for a JSON
// column.
type NullJSON struct {
Value interface{} // Val contains the value when it is non-NULL, and nil when NULL.
Valid bool // Valid is true if Json is not NULL.
Expand Down Expand Up @@ -554,6 +820,11 @@ func (n *NullJSON) UnmarshalJSON(payload []byte) error {
return nil
}

// GormDataType is used by gorm to determine the default data type for fields with this type.
func (n NullJSON) GormDataType() string {
return "JSON"
}

// NullRow represents a Cloud Spanner STRUCT that may be NULL.
// See also the document for Row.
// Note that NullRow is not a valid Cloud Spanner column Type.
Expand Down Expand Up @@ -700,7 +971,12 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
return err
}
*p = x
case *NullString, **string:
case *NullString, **string, *sql.NullString:
// Most Null* types are automatically supported for both spanner.Null* and sql.Null* types, except for
// NullString, and we need to add explicit support for it here. The reason that the other types are
// automatically supported is that they use the same field names (e.g. spanner.NullBool and sql.NullBool both
// contain the fields Valid and Bool). spanner.NullString has a field StringVal, sql.NullString has a field
// String.
if p == nil {
return errNilDst(p)
}
Expand All @@ -713,6 +989,8 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
*sp = NullString{}
case **string:
*sp = nil
case *sql.NullString:
*sp = sql.NullString{}
}
break
}
Expand All @@ -726,6 +1004,9 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
sp.StringVal = x
case **string:
*sp = &x
case *sql.NullString:
sp.Valid = true
sp.String = x
}
case *[]NullString, *[]*string:
if p == nil {
Expand Down Expand Up @@ -2737,6 +3018,11 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) {
return encodeValue(v.StringVal)
}
pt = stringType()
case sql.NullString:
if v.Valid {
return encodeValue(v.String)
}
pt = stringType()
case []string:
if v != nil {
pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
Expand Down

0 comments on commit 4537b45

Please sign in to comment.