diff --git a/spanner/client_test.go b/spanner/client_test.go index a82996102856..006e49afb1c9 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -1833,6 +1833,163 @@ func TestClient_QueryWithCallOptions(t *testing.T) { } } +func TestClient_EncodeCustomFieldType(t *testing.T) { + t.Parallel() + + type typesTable struct { + Int customStructToInt `spanner:"Int"` + String customStructToString `spanner:"String"` + Float customStructToFloat `spanner:"Float"` + Bool customStructToBool `spanner:"Bool"` + Time customStructToTime `spanner:"Time"` + Date customStructToDate `spanner:"Date"` + } + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + ctx := context.Background() + + d := typesTable{ + Int: customStructToInt{1, 23}, + String: customStructToString{"A", "B"}, + Float: customStructToFloat{1.23, 12.3}, + Bool: customStructToBool{true, false}, + Time: customStructToTime{"A", "B"}, + Date: customStructToDate{"A", "B"}, + } + + m, err := InsertStruct("Types", &d) + if err != nil { + t.Fatalf("err: %v", err) + } + + ms := []*Mutation{m} + _, err = client.Apply(ctx, ms) + if err != nil { + t.Fatalf("err: %v", err) + } + + reqs := drainRequestsFromServer(server.TestSpanner) + + for _, req := range reqs { + if commitReq, ok := req.(*sppb.CommitRequest); ok { + val := commitReq.Mutations[0].GetInsert().Values[0] + + if got, want := val.Values[0].GetStringValue(), "123"; got != want { + t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[0].GetKind(), want) + } + if got, want := val.Values[1].GetStringValue(), "A-B"; got != want { + t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[1].GetKind(), want) + } + if got, want := val.Values[2].GetNumberValue(), float64(123.123); got != want { + t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[2].GetKind(), want) + } + if got, want := val.Values[3].GetBoolValue(), true; got != want { + t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[3].GetKind(), want) + } + if got, want := val.Values[4].GetStringValue(), "2016-11-15T15:04:05.999999999Z"; got != want { + t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[4].GetKind(), want) + } + if got, want := val.Values[5].GetStringValue(), "2016-11-15"; got != want { + t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[5].GetKind(), want) + } + } + } +} + +func setupDecodeCustomFieldResult(server *MockedSpannerInMemTestServer, stmt string) error { + metadata := &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + {Name: "String", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + {Name: "Float", Type: &sppb.Type{Code: sppb.TypeCode_FLOAT64}}, + {Name: "Bool", Type: &sppb.Type{Code: sppb.TypeCode_BOOL}}, + {Name: "Time", Type: &sppb.Type{Code: sppb.TypeCode_TIMESTAMP}}, + {Name: "Date", Type: &sppb.Type{Code: sppb.TypeCode_DATE}}, + }, + }, + } + rowValues := []*structpb.Value{ + {Kind: &structpb.Value_StringValue{StringValue: "123"}}, + {Kind: &structpb.Value_StringValue{StringValue: "A-B"}}, + {Kind: &structpb.Value_NumberValue{NumberValue: float64(123.123)}}, + {Kind: &structpb.Value_BoolValue{BoolValue: true}}, + {Kind: &structpb.Value_StringValue{StringValue: "2016-11-15T15:04:05.999999999Z"}}, + {Kind: &structpb.Value_StringValue{StringValue: "2016-11-15"}}, + } + rows := []*structpb.ListValue{ + {Values: rowValues}, + } + resultSet := &sppb.ResultSet{ + Metadata: metadata, + Rows: rows, + } + result := &StatementResult{ + Type: StatementResultResultSet, + ResultSet: resultSet, + } + return server.TestSpanner.PutStatementResult(stmt, result) +} + +func TestClient_DecodeCustomFieldType(t *testing.T) { + t.Parallel() + + type typesTable struct { + Int customStructToInt `spanner:"Int"` + String customStructToString `spanner:"String"` + Float customStructToFloat `spanner:"Float"` + Bool customStructToBool `spanner:"Bool"` + Time customStructToTime `spanner:"Time"` + Date customStructToDate `spanner:"Date"` + } + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + + query := "SELECT * FROM Types" + setupDecodeCustomFieldResult(server, query) + + ctx := context.Background() + stmt := Statement{SQL: query} + iter := client.Single().Query(ctx, stmt) + defer iter.Stop() + + var results []typesTable + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("failed to get next: %v", err) + } + + var d typesTable + if err := row.ToStruct(&d); err != nil { + t.Fatalf("failed to convert a row to a struct: %v", err) + } + results = append(results, d) + } + + if len(results) > 1 { + t.Fatalf("mismatch length of array: got %v, want 1", results) + } + + want := typesTable{ + Int: customStructToInt{1, 23}, + String: customStructToString{"A", "B"}, + Float: customStructToFloat{1.23, 12.3}, + Bool: customStructToBool{true, false}, + Time: customStructToTime{"A", "B"}, + Date: customStructToDate{"A", "B"}, + } + got := results[0] + if !testEqual(got, want) { + t.Fatalf("mismatch result: got %v, want %v", got, want) + } +} + func TestBatchReadOnlyTransaction_QueryOptions(t *testing.T) { ctx := context.Background() qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}} diff --git a/spanner/value.go b/spanner/value.go index b3cad119408c..d206478eedfa 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -50,6 +50,51 @@ var ( jsonNullBytes = []byte("null") ) +// Encoder is the interface implemented by a custom type that can be encoded to +// a supported type by Spanner. A code example: +// +// type customField struct { +// Prefix string +// Suffix string +// } +// +// // Convert a customField value to a string +// func (cf customField) EncodeSpanner() (interface{}, error) { +// var b bytes.Buffer +// b.WriteString(cf.Prefix) +// b.WriteString("-") +// b.WriteString(cf.Suffix) +// return b.String(), nil +// } +type Encoder interface { + EncodeSpanner() (interface{}, error) +} + +// Decoder is the interface implemented by a custom type that can be decoded +// from a supported type by Spanner. A code example: +// +// type customField struct { +// Prefix string +// Suffix string +// } +// +// // Convert a string to a customField value +// func (cf *customField) DecodeSpanner(val interface{}) (err error) { +// strVal, ok := val.(string) +// if !ok { +// return fmt.Errorf("failed to decode customField: %v", val) +// } +// s := strings.Split(strVal, "-") +// if len(s) > 1 { +// cf.Prefix = s[0] +// cf.Suffix = s[1] +// } +// return nil +// } +type Decoder interface { + DecodeSpanner(input interface{}) error +} + // NullableValue is the interface implemented by all null value wrapper types. type NullableValue interface { // IsNull returns true if the underlying database value is null. @@ -1112,6 +1157,16 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { case *GenericColumnValue: *p = GenericColumnValue{Type: t, Value: v} default: + // Check if the pointer is a custom type that implements spanner.Decoder + // interface. + if decodedVal, ok := ptr.(Decoder); ok { + x, err := getGenericValue(v) + if err != nil { + return err + } + return decodedVal.DecodeSpanner(x) + } + // Check if the pointer is a variant of a base type. decodableType := getDecodableSpannerType(ptr) if decodableType != spannerTypeUnknown { @@ -1613,6 +1668,20 @@ func getListValue(v *proto3.Value) (*proto3.ListValue, error) { return nil, errSrcVal(v, "List") } +// getGenericValue returns the interface{} value encoded in proto3.Value. +func getGenericValue(v *proto3.Value) (interface{}, error) { + switch x := v.GetKind().(type) { + case *proto3.Value_NumberValue: + return x.NumberValue, nil + case *proto3.Value_BoolValue: + return x.BoolValue, nil + case *proto3.Value_StringValue: + return x.StringValue, nil + default: + return 0, errSrcVal(v, "Number, Bool, String") + } +} + // errUnexpectedNumStr returns error for decoder getting a unexpected string for // representing special float values. func errUnexpectedNumStr(s string) error { @@ -2381,6 +2450,16 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { case []GenericColumnValue: return nil, nil, errEncoderUnsupportedType(v) default: + // Check if the value is a custom type that implements spanner.Encoder + // interface. + if encodedVal, ok := v.(Encoder); ok { + nv, err := encodedVal.EncodeSpanner() + if err != nil { + return nil, nil, err + } + return encodeValue(nv) + } + // Check if the value is a variant of a base type. decodableType := getDecodableSpannerType(v) if decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid { @@ -2652,6 +2731,11 @@ func isSupportedMutationType(v interface{}) bool { GenericColumnValue: return true default: + // Check if the custom type implements spanner.Encoder interface. + if _, ok := v.(Encoder); ok { + return true + } + decodableType := getDecodableSpannerType(v) return decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid } diff --git a/spanner/value_test.go b/spanner/value_test.go index 87700590c239..d447639dc97e 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -59,6 +59,125 @@ func mustParseDate(s string) civil.Date { return d } +type customStructToString struct { + A string + B string +} + +// Convert the customStructToString +func (c customStructToString) EncodeSpanner() (interface{}, error) { + return "A-B", nil +} + +// Convert to customStructToString +func (c *customStructToString) DecodeSpanner(val interface{}) (err error) { + c.A = "A" + c.B = "B" + return nil +} + +type customStructToInt struct { + A int64 + B int64 +} + +// Convert the customStructToInt +func (c customStructToInt) EncodeSpanner() (interface{}, error) { + return 123, nil +} + +// Convert to customStructToInt +func (c *customStructToInt) DecodeSpanner(val interface{}) (err error) { + c.A = 1 + c.B = 23 + return nil +} + +type customStructToFloat struct { + A float64 + B float64 +} + +// Convert the customStructToFloat +func (c customStructToFloat) EncodeSpanner() (interface{}, error) { + return 123.123, nil +} + +// Convert to customStructToFloat +func (c *customStructToFloat) DecodeSpanner(val interface{}) (err error) { + c.A = 1.23 + c.B = 12.3 + return nil +} + +type customStructToBool struct { + A bool + B bool +} + +// Convert the customStructToBool +func (c customStructToBool) EncodeSpanner() (interface{}, error) { + return true, nil +} + +// Convert to customStructToBool +func (c *customStructToBool) DecodeSpanner(val interface{}) (err error) { + c.A = true + c.B = false + return nil +} + +type customStructToBytes struct { + A []byte + B []byte +} + +// Convert the customStructToBytes +func (c customStructToBytes) EncodeSpanner() (interface{}, error) { + return []byte("AB"), nil +} + +// Convert to customStructToBytes +func (c *customStructToBytes) DecodeSpanner(val interface{}) (err error) { + c.A = []byte("A") + c.B = []byte("B") + return nil +} + +type customStructToTime struct { + A string + B string +} + +// Convert the customStructToTime +func (c customStructToTime) EncodeSpanner() (interface{}, error) { + return t1, nil +} + +// Convert to customStructToTime +func (c *customStructToTime) DecodeSpanner(val interface{}) (err error) { + c.A = "A" + c.B = "B" + return nil +} + +type customStructToDate struct { + A string + B string +} + +// Convert the customStructToDate +func (c customStructToDate) EncodeSpanner() (interface{}, error) { + return d1, nil +} + +// Convert to customStructToDate +func (c *customStructToDate) DecodeSpanner(val interface{}) (err error) { + c.A = "A" + c.B = "B" + return nil +} + // Test encoding Values. func TestEncodeValue(t *testing.T) { type CustomString string @@ -239,6 +358,14 @@ func TestEncodeValue(t *testing.T) { {[]CustomDate{CustomDate(d1), CustomDate(d2)}, listProto(dateProto(d1), dateProto(d2)), listType(tDate), "[]CustomDate"}, {[]CustomNullDate(nil), nullProto(), listType(tDate), "null []CustomNullDate"}, {[]CustomNullDate{{d1, true}, {civil.Date{}, false}}, listProto(dateProto(d1), nullProto()), listType(tDate), "[]NullDate"}, + // CUSTOM STRUCT + {customStructToString{"A", "B"}, stringProto("A-B"), tString, "a struct to string"}, + {customStructToInt{1, 23}, intProto(123), tInt, "a struct to int"}, + {customStructToFloat{1.23, 12.3}, floatProto(123.123), tFloat, "a struct to float"}, + {customStructToBool{true, false}, boolProto(true), tBool, "a struct to bool"}, + {customStructToBytes{[]byte("A"), []byte("B")}, bytesProto([]byte("AB")), tBytes, "a struct to bytes"}, + {customStructToTime{"A", "B"}, timeProto(tValue), tTime, "a struct to time"}, + {customStructToDate{"A", "B"}, dateProto(dValue), tDate, "a struct to date"}, } { got, gotType, err := encodeValue(test.in) if err != nil { @@ -1441,6 +1568,14 @@ func TestDecodeValue(t *testing.T) { {desc: "decode ARRAY to []CustomDate", proto: listProto(dateProto(d1), dateProto(d2)), protoType: listType(dateType()), want: []CustomDate{CustomDate(d1), CustomDate(d2)}}, {desc: "decode NULL to []CustomNullDate", proto: nullProto(), protoType: listType(dateType()), want: []CustomNullDate(nil)}, {desc: "decode ARRAY to []CustomNullDate", proto: listProto(dateProto(d1), nullProto(), dateProto(d2)), protoType: listType(dateType()), want: []CustomNullDate{{d1, true}, {}, {d2, true}}}, + // CUSTOM STRUCT + {desc: "decode STRING to CustomStructToString", proto: stringProto("A-B"), protoType: stringType(), want: customStructToString{"A", "B"}}, + {desc: "decode INT64 to CustomStructToInt", proto: intProto(123), protoType: intType(), want: customStructToInt{1, 23}}, + {desc: "decode FLOAT64 to CustomStructToFloat", proto: floatProto(123.123), protoType: floatType(), want: customStructToFloat{1.23, 12.3}}, + {desc: "decode BOOL to CustomStructToBool", proto: boolProto(true), protoType: boolType(), want: customStructToBool{true, false}}, + {desc: "decode BYTES to CustomStructToBytes", proto: bytesProto([]byte("AB")), protoType: bytesType(), want: customStructToBytes{[]byte("A"), []byte("B")}}, + {desc: "decode TIMESTAMP to CustomStructToTime", proto: timeProto(t1), protoType: timeType(), want: customStructToTime{"A", "B"}}, + {desc: "decode DATE to CustomStructToDate", proto: dateProto(d1), protoType: dateType(), want: customStructToDate{"A", "B"}}, } { gotp := reflect.New(reflect.TypeOf(test.want)) err := decodeValue(test.proto, test.protoType, gotp.Interface())