Skip to content

Commit

Permalink
fix(spanner): allow decoding null values to spanner.Decoder
Browse files Browse the repository at this point in the history
Allow NULL values from the database to be passed in to the DecodeSpanner method
of a struct that implements spanner.Decoder.

Fixes #4552
  • Loading branch information
olavloite committed Aug 5, 2021
1 parent e9ddbfe commit 9cf7bc4
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
22 changes: 20 additions & 2 deletions spanner/value.go
Expand Up @@ -1330,7 +1330,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
// Check if the pointer is a custom type that implements spanner.Decoder
// interface.
if decodedVal, ok := ptr.(Decoder); ok {
x, err := getGenericValue(v)
x, err := getGenericValue(t, v)
if err != nil {
return err
}
Expand Down Expand Up @@ -1909,19 +1909,37 @@ func getListValue(v *proto3.Value) (*proto3.ListValue, error) {
}

// getGenericValue returns the interface{} value encoded in proto3.Value.
func getGenericValue(v *proto3.Value) (interface{}, error) {
func getGenericValue(t *sppb.Type, v *proto3.Value) (interface{}, error) {
switch x := v.GetKind().(type) {
case *proto3.Value_NumberValue:
return x.NumberValue, nil
case *proto3.Value_BoolValue:
return x.BoolValue, nil
case *proto3.Value_StringValue:
return x.StringValue, nil
case *proto3.Value_NullValue:
return getTypedNil(t)
default:
return 0, errSrcVal(v, "Number, Bool, String")
}
}

func getTypedNil(t *sppb.Type) (interface{}, error) {
switch t.Code {
case sppb.TypeCode_FLOAT64:
var f *float64
return f, nil
case sppb.TypeCode_BOOL:
var b *bool
return b, nil
default:
// The encoding for most types is string, except for the ones listed
// above.
var s *string
return s, nil
}
}

// errUnexpectedNumericStr returns error for decoder getting an unexpected
// string for representing special numeric values.
func errUnexpectedNumericStr(s string) error {
Expand Down
31 changes: 30 additions & 1 deletion spanner/value_test.go
Expand Up @@ -179,6 +179,21 @@ func (c *customStructToDate) DecodeSpanner(val interface{}) (err error) {
return nil
}

type customStructToNull struct {
val interface{}
}

func (c customStructToNull) EncodeSpanner() (interface{}, error) {
return c.val, nil
}

func (c *customStructToNull) DecodeSpanner(val interface{}) (err error) {
if reflect.ValueOf(val).IsNil() {
return nil
}
return fmt.Errorf("val mismatch: expected nil, got %v", val)
}

// Test encoding Values.
func TestEncodeValue(t *testing.T) {
type CustomString string
Expand Down Expand Up @@ -384,6 +399,14 @@ func TestEncodeValue(t *testing.T) {
{customStructToBytes{[]byte("A"), []byte("B")}, bytesProto([]byte("AB")), tBytes, "a struct to bytes"},
{customStructToTime{"A", "B"}, timeProto(tValue), tTime, "a struct to time"},
{customStructToDate{"A", "B"}, dateProto(dValue), tDate, "a struct to date"},
{customStructToNull{val: bNilPtr}, nullProto(), tBool, "a struct to null bool"},
{customStructToNull{val: []byte(nil)}, nullProto(), tBytes, "a struct to null bytes"},
{customStructToNull{val: sNilPtr}, nullProto(), tString, "a struct to null string"},
{customStructToNull{val: iNilPtr}, nullProto(), tInt, "a struct to null int"},
{customStructToNull{val: fNilPtr}, nullProto(), tFloat, "a struct to null float"},
{customStructToNull{val: numNilPtr}, nullProto(), tNumeric, "a struct to null numeric"},
{customStructToNull{val: dNilPtr}, nullProto(), tDate, "a struct to null date"},
{customStructToNull{val: tNilPtr}, nullProto(), tTime, "a struct to null timestamp"},
// CUSTOM NUMERIC / CUSTOM NUMERIC ARRAY
{CustomNumeric(*numValuePtr), numericProto(numValuePtr), tNumeric, "CustomNumeric"},
{CustomNullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "CustomNullNumeric with value"},
Expand Down Expand Up @@ -1632,6 +1655,12 @@ func TestDecodeValue(t *testing.T) {
{desc: "decode BYTES to CustomStructToBytes", proto: bytesProto([]byte("AB")), protoType: bytesType(), want: customStructToBytes{[]byte("A"), []byte("B")}},
{desc: "decode TIMESTAMP to CustomStructToTime", proto: timeProto(t1), protoType: timeType(), want: customStructToTime{"A", "B"}},
{desc: "decode DATE to CustomStructToDate", proto: dateProto(d1), protoType: dateType(), want: customStructToDate{"A", "B"}},
{desc: "decode NULL bool to CustomStructToNull", proto: nullProto(), protoType: boolType(), want: customStructToNull{}},
{desc: "decode NULL float to CustomStructToNull", proto: nullProto(), protoType: floatType(), want: customStructToNull{}},
{desc: "decode NULL string to CustomStructToNull", proto: nullProto(), protoType: stringType(), want: customStructToNull{}},
{desc: "decode NULL array of bool to CustomStructToNull", proto: nullProto(), protoType: listType(boolType()), want: customStructToNull{}},
{desc: "decode NULL array of float to CustomStructToNull", proto: nullProto(), protoType: listType(floatType()), want: customStructToNull{}},
{desc: "decode NULL array of string to CustomStructToNull", proto: nullProto(), protoType: listType(stringType()), want: customStructToNull{}},
} {
gotp := reflect.New(reflect.TypeOf(test.want))
v := gotp.Interface()
Expand Down Expand Up @@ -1665,7 +1694,7 @@ func TestDecodeValue(t *testing.T) {
continue
}
got := reflect.Indirect(gotp).Interface()
if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{})) {
if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{}, customStructToNull{})) {
t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want)
}
}
Expand Down

0 comments on commit 9cf7bc4

Please sign in to comment.