Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(spanner): allow decoding null values to spanner.Decoder #4558

Merged
merged 2 commits into from Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 20 additions & 2 deletions spanner/value.go
Expand Up @@ -1330,7 +1330,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
// Check if the pointer is a custom type that implements spanner.Decoder
// interface.
if decodedVal, ok := ptr.(Decoder); ok {
x, err := getGenericValue(v)
x, err := getGenericValue(t, v)
if err != nil {
return err
}
Expand Down Expand Up @@ -1909,19 +1909,37 @@ func getListValue(v *proto3.Value) (*proto3.ListValue, error) {
}

// getGenericValue returns the interface{} value encoded in proto3.Value.
func getGenericValue(v *proto3.Value) (interface{}, error) {
func getGenericValue(t *sppb.Type, v *proto3.Value) (interface{}, error) {
switch x := v.GetKind().(type) {
case *proto3.Value_NumberValue:
return x.NumberValue, nil
case *proto3.Value_BoolValue:
return x.BoolValue, nil
case *proto3.Value_StringValue:
return x.StringValue, nil
case *proto3.Value_NullValue:
return getTypedNil(t)
default:
return 0, errSrcVal(v, "Number, Bool, String")
}
}

func getTypedNil(t *sppb.Type) (interface{}, error) {
switch t.Code {
case sppb.TypeCode_FLOAT64:
var f *float64
return f, nil
case sppb.TypeCode_BOOL:
var b *bool
return b, nil
default:
// The encoding for most types is string, except for the ones listed
// above.
var s *string
return s, nil
}
}

// errUnexpectedNumericStr returns error for decoder getting an unexpected
// string for representing special numeric values.
func errUnexpectedNumericStr(s string) error {
Expand Down
31 changes: 30 additions & 1 deletion spanner/value_test.go
Expand Up @@ -179,6 +179,21 @@ func (c *customStructToDate) DecodeSpanner(val interface{}) (err error) {
return nil
}

type customStructToNull struct {
val interface{}
}

func (c customStructToNull) EncodeSpanner() (interface{}, error) {
return c.val, nil
}

func (c *customStructToNull) DecodeSpanner(val interface{}) (err error) {
if reflect.ValueOf(val).IsNil() {
return nil
}
return fmt.Errorf("val mismatch: expected nil, got %v", val)
}

// Test encoding Values.
func TestEncodeValue(t *testing.T) {
type CustomString string
Expand Down Expand Up @@ -384,6 +399,14 @@ func TestEncodeValue(t *testing.T) {
{customStructToBytes{[]byte("A"), []byte("B")}, bytesProto([]byte("AB")), tBytes, "a struct to bytes"},
{customStructToTime{"A", "B"}, timeProto(tValue), tTime, "a struct to time"},
{customStructToDate{"A", "B"}, dateProto(dValue), tDate, "a struct to date"},
{customStructToNull{val: bNilPtr}, nullProto(), tBool, "a struct to null bool"},
{customStructToNull{val: []byte(nil)}, nullProto(), tBytes, "a struct to null bytes"},
{customStructToNull{val: sNilPtr}, nullProto(), tString, "a struct to null string"},
{customStructToNull{val: iNilPtr}, nullProto(), tInt, "a struct to null int"},
{customStructToNull{val: fNilPtr}, nullProto(), tFloat, "a struct to null float"},
{customStructToNull{val: numNilPtr}, nullProto(), tNumeric, "a struct to null numeric"},
{customStructToNull{val: dNilPtr}, nullProto(), tDate, "a struct to null date"},
{customStructToNull{val: tNilPtr}, nullProto(), tTime, "a struct to null timestamp"},
// CUSTOM NUMERIC / CUSTOM NUMERIC ARRAY
{CustomNumeric(*numValuePtr), numericProto(numValuePtr), tNumeric, "CustomNumeric"},
{CustomNullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "CustomNullNumeric with value"},
Expand Down Expand Up @@ -1632,6 +1655,12 @@ func TestDecodeValue(t *testing.T) {
{desc: "decode BYTES to CustomStructToBytes", proto: bytesProto([]byte("AB")), protoType: bytesType(), want: customStructToBytes{[]byte("A"), []byte("B")}},
{desc: "decode TIMESTAMP to CustomStructToTime", proto: timeProto(t1), protoType: timeType(), want: customStructToTime{"A", "B"}},
{desc: "decode DATE to CustomStructToDate", proto: dateProto(d1), protoType: dateType(), want: customStructToDate{"A", "B"}},
{desc: "decode NULL bool to CustomStructToNull", proto: nullProto(), protoType: boolType(), want: customStructToNull{}},
{desc: "decode NULL float to CustomStructToNull", proto: nullProto(), protoType: floatType(), want: customStructToNull{}},
{desc: "decode NULL string to CustomStructToNull", proto: nullProto(), protoType: stringType(), want: customStructToNull{}},
{desc: "decode NULL array of bool to CustomStructToNull", proto: nullProto(), protoType: listType(boolType()), want: customStructToNull{}},
{desc: "decode NULL array of float to CustomStructToNull", proto: nullProto(), protoType: listType(floatType()), want: customStructToNull{}},
{desc: "decode NULL array of string to CustomStructToNull", proto: nullProto(), protoType: listType(stringType()), want: customStructToNull{}},
} {
gotp := reflect.New(reflect.TypeOf(test.want))
v := gotp.Interface()
Expand Down Expand Up @@ -1665,7 +1694,7 @@ func TestDecodeValue(t *testing.T) {
continue
}
got := reflect.Indirect(gotp).Interface()
if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{})) {
if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{}, customStructToNull{})) {
t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want)
}
}
Expand Down