Skip to content

Commit

Permalink
feat(spanner): support custom field type (#2614)
Browse files Browse the repository at this point in the history
* Support custom field type.
* Add an end-to-end test for encode/decode custom fields.
* Add code examples in the docs.
  • Loading branch information
hengfengli committed Jul 23, 2020
1 parent 188927f commit 5ffb250
Show file tree
Hide file tree
Showing 3 changed files with 376 additions and 0 deletions.
157 changes: 157 additions & 0 deletions spanner/client_test.go
Expand Up @@ -1833,6 +1833,163 @@ func TestClient_QueryWithCallOptions(t *testing.T) {
}
}

func TestClient_EncodeCustomFieldType(t *testing.T) {
t.Parallel()

type typesTable struct {
Int customStructToInt `spanner:"Int"`
String customStructToString `spanner:"String"`
Float customStructToFloat `spanner:"Float"`
Bool customStructToBool `spanner:"Bool"`
Time customStructToTime `spanner:"Time"`
Date customStructToDate `spanner:"Date"`
}

server, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()

d := typesTable{
Int: customStructToInt{1, 23},
String: customStructToString{"A", "B"},
Float: customStructToFloat{1.23, 12.3},
Bool: customStructToBool{true, false},
Time: customStructToTime{"A", "B"},
Date: customStructToDate{"A", "B"},
}

m, err := InsertStruct("Types", &d)
if err != nil {
t.Fatalf("err: %v", err)
}

ms := []*Mutation{m}
_, err = client.Apply(ctx, ms)
if err != nil {
t.Fatalf("err: %v", err)
}

reqs := drainRequestsFromServer(server.TestSpanner)

for _, req := range reqs {
if commitReq, ok := req.(*sppb.CommitRequest); ok {
val := commitReq.Mutations[0].GetInsert().Values[0]

if got, want := val.Values[0].GetStringValue(), "123"; got != want {
t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[0].GetKind(), want)
}
if got, want := val.Values[1].GetStringValue(), "A-B"; got != want {
t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[1].GetKind(), want)
}
if got, want := val.Values[2].GetNumberValue(), float64(123.123); got != want {
t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[2].GetKind(), want)
}
if got, want := val.Values[3].GetBoolValue(), true; got != want {
t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[3].GetKind(), want)
}
if got, want := val.Values[4].GetStringValue(), "2016-11-15T15:04:05.999999999Z"; got != want {
t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[4].GetKind(), want)
}
if got, want := val.Values[5].GetStringValue(), "2016-11-15"; got != want {
t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[5].GetKind(), want)
}
}
}
}

func setupDecodeCustomFieldResult(server *MockedSpannerInMemTestServer, stmt string) error {
metadata := &sppb.ResultSetMetadata{
RowType: &sppb.StructType{
Fields: []*sppb.StructType_Field{
{Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}},
{Name: "String", Type: &sppb.Type{Code: sppb.TypeCode_STRING}},
{Name: "Float", Type: &sppb.Type{Code: sppb.TypeCode_FLOAT64}},
{Name: "Bool", Type: &sppb.Type{Code: sppb.TypeCode_BOOL}},
{Name: "Time", Type: &sppb.Type{Code: sppb.TypeCode_TIMESTAMP}},
{Name: "Date", Type: &sppb.Type{Code: sppb.TypeCode_DATE}},
},
},
}
rowValues := []*structpb.Value{
{Kind: &structpb.Value_StringValue{StringValue: "123"}},
{Kind: &structpb.Value_StringValue{StringValue: "A-B"}},
{Kind: &structpb.Value_NumberValue{NumberValue: float64(123.123)}},
{Kind: &structpb.Value_BoolValue{BoolValue: true}},
{Kind: &structpb.Value_StringValue{StringValue: "2016-11-15T15:04:05.999999999Z"}},
{Kind: &structpb.Value_StringValue{StringValue: "2016-11-15"}},
}
rows := []*structpb.ListValue{
{Values: rowValues},
}
resultSet := &sppb.ResultSet{
Metadata: metadata,
Rows: rows,
}
result := &StatementResult{
Type: StatementResultResultSet,
ResultSet: resultSet,
}
return server.TestSpanner.PutStatementResult(stmt, result)
}

func TestClient_DecodeCustomFieldType(t *testing.T) {
t.Parallel()

type typesTable struct {
Int customStructToInt `spanner:"Int"`
String customStructToString `spanner:"String"`
Float customStructToFloat `spanner:"Float"`
Bool customStructToBool `spanner:"Bool"`
Time customStructToTime `spanner:"Time"`
Date customStructToDate `spanner:"Date"`
}

server, client, teardown := setupMockedTestServer(t)
defer teardown()

query := "SELECT * FROM Types"
setupDecodeCustomFieldResult(server, query)

ctx := context.Background()
stmt := Statement{SQL: query}
iter := client.Single().Query(ctx, stmt)
defer iter.Stop()

var results []typesTable
for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatalf("failed to get next: %v", err)
}

var d typesTable
if err := row.ToStruct(&d); err != nil {
t.Fatalf("failed to convert a row to a struct: %v", err)
}
results = append(results, d)
}

if len(results) > 1 {
t.Fatalf("mismatch length of array: got %v, want 1", results)
}

want := typesTable{
Int: customStructToInt{1, 23},
String: customStructToString{"A", "B"},
Float: customStructToFloat{1.23, 12.3},
Bool: customStructToBool{true, false},
Time: customStructToTime{"A", "B"},
Date: customStructToDate{"A", "B"},
}
got := results[0]
if !testEqual(got, want) {
t.Fatalf("mismatch result: got %v, want %v", got, want)
}
}

func TestBatchReadOnlyTransaction_QueryOptions(t *testing.T) {
ctx := context.Background()
qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}}
Expand Down
84 changes: 84 additions & 0 deletions spanner/value.go
Expand Up @@ -50,6 +50,51 @@ var (
jsonNullBytes = []byte("null")
)

// Encoder is the interface implemented by a custom type that can be encoded to
// a supported type by Spanner. A code example:
//
// type customField struct {
// Prefix string
// Suffix string
// }
//
// // Convert a customField value to a string
// func (cf customField) EncodeSpanner() (interface{}, error) {
// var b bytes.Buffer
// b.WriteString(cf.Prefix)
// b.WriteString("-")
// b.WriteString(cf.Suffix)
// return b.String(), nil
// }
type Encoder interface {
EncodeSpanner() (interface{}, error)
}

// Decoder is the interface implemented by a custom type that can be decoded
// from a supported type by Spanner. A code example:
//
// type customField struct {
// Prefix string
// Suffix string
// }
//
// // Convert a string to a customField value
// func (cf *customField) DecodeSpanner(val interface{}) (err error) {
// strVal, ok := val.(string)
// if !ok {
// return fmt.Errorf("failed to decode customField: %v", val)
// }
// s := strings.Split(strVal, "-")
// if len(s) > 1 {
// cf.Prefix = s[0]
// cf.Suffix = s[1]
// }
// return nil
// }
type Decoder interface {
DecodeSpanner(input interface{}) error
}

// NullableValue is the interface implemented by all null value wrapper types.
type NullableValue interface {
// IsNull returns true if the underlying database value is null.
Expand Down Expand Up @@ -1112,6 +1157,16 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
case *GenericColumnValue:
*p = GenericColumnValue{Type: t, Value: v}
default:
// Check if the pointer is a custom type that implements spanner.Decoder
// interface.
if decodedVal, ok := ptr.(Decoder); ok {
x, err := getGenericValue(v)
if err != nil {
return err
}
return decodedVal.DecodeSpanner(x)
}

// Check if the pointer is a variant of a base type.
decodableType := getDecodableSpannerType(ptr)
if decodableType != spannerTypeUnknown {
Expand Down Expand Up @@ -1613,6 +1668,20 @@ func getListValue(v *proto3.Value) (*proto3.ListValue, error) {
return nil, errSrcVal(v, "List")
}

// getGenericValue returns the interface{} value encoded in proto3.Value.
func getGenericValue(v *proto3.Value) (interface{}, error) {
switch x := v.GetKind().(type) {
case *proto3.Value_NumberValue:
return x.NumberValue, nil
case *proto3.Value_BoolValue:
return x.BoolValue, nil
case *proto3.Value_StringValue:
return x.StringValue, nil
default:
return 0, errSrcVal(v, "Number, Bool, String")
}
}

// errUnexpectedNumStr returns error for decoder getting a unexpected string for
// representing special float values.
func errUnexpectedNumStr(s string) error {
Expand Down Expand Up @@ -2381,6 +2450,16 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) {
case []GenericColumnValue:
return nil, nil, errEncoderUnsupportedType(v)
default:
// Check if the value is a custom type that implements spanner.Encoder
// interface.
if encodedVal, ok := v.(Encoder); ok {
nv, err := encodedVal.EncodeSpanner()
if err != nil {
return nil, nil, err
}
return encodeValue(nv)
}

// Check if the value is a variant of a base type.
decodableType := getDecodableSpannerType(v)
if decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid {
Expand Down Expand Up @@ -2652,6 +2731,11 @@ func isSupportedMutationType(v interface{}) bool {
GenericColumnValue:
return true
default:
// Check if the custom type implements spanner.Encoder interface.
if _, ok := v.(Encoder); ok {
return true
}

decodableType := getDecodableSpannerType(v)
return decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid
}
Expand Down

0 comments on commit 5ffb250

Please sign in to comment.