diff --git a/.changelog/2096a6beb82d44bea4a469c197f6de40.json b/.changelog/2096a6beb82d44bea4a469c197f6de40.json new file mode 100644 index 00000000000..e580502da74 --- /dev/null +++ b/.changelog/2096a6beb82d44bea4a469c197f6de40.json @@ -0,0 +1,8 @@ +{ + "id": "2096a6be-b82d-44be-a4a4-69c197f6de40", + "type": "feature", + "description": "Add support for expression names with dots via new NameBuilder function NameNoDotSplit, related to [aws/aws-sdk-go#2570](https://github.com/aws/aws-sdk-go/issues/2570)", + "modules": [ + "feature/dynamodb/expression" + ] +} \ No newline at end of file diff --git a/.changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json b/.changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json new file mode 100644 index 00000000000..4eb187140df --- /dev/null +++ b/.changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json @@ -0,0 +1,9 @@ +{ + "id": "98a8c469-e1d6-4c9a-a06b-0e467ac61dd9", + "type": "bugfix", + "description": "Fixes [#1569](https://github.com/aws/aws-sdk-go-v2/issues/1569) inconsistent serialization of Go struct field names", + "modules": [ + "feature/dynamodb/attributevalue", + "feature/dynamodbstreams/attributevalue" + ] +} diff --git a/.changelog/db81731fa3ab450e9ea3535a0d4aaedd.json b/.changelog/db81731fa3ab450e9ea3535a0d4aaedd.json new file mode 100644 index 00000000000..36c874b163c --- /dev/null +++ b/.changelog/db81731fa3ab450e9ea3535a0d4aaedd.json @@ -0,0 +1,9 @@ +{ + "id": "db81731f-a3ab-450e-9ea3-535a0d4aaedd", + "type": "feature", + "description": "Fixes [#645](https://github.com/aws/aws-sdk-go-v2/issues/645), [#411](https://github.com/aws/aws-sdk-go-v2/issues/411) by adding support for (un)marshaling AttributeValue maps to Go maps key types of string, number, bool, and types implementing encoding.Text(un)Marshaler interface", + "modules": [ + "feature/dynamodb/attributevalue", + "feature/dynamodbstreams/attributevalue" + ] +} \ No newline at end of file diff --git a/feature/dynamodb/attributevalue/decode.go b/feature/dynamodb/attributevalue/decode.go index 5a02853dc95..fc3f322dd01 100644 --- a/feature/dynamodb/attributevalue/decode.go +++ b/feature/dynamodb/attributevalue/decode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -197,7 +198,7 @@ func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out int } // DecoderOptions is a collection of options to configure how the decoder -// unmarshalls the value. +// unmarshals the value. type DecoderOptions struct { // Support other custom struct tag keys, such as `yaml`, `json`, or `toml`. // Note that values provided with a custom TagKey must also be supported @@ -221,7 +222,7 @@ type Decoder struct { // NewDecoder creates a new Decoder with default configuration. Use // the `opts` functional options to override the default configuration. func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder { - var options DecoderOptions + options := DecoderOptions{TagKey: defaultTagKey} for _, fn := range optFns { fn(&options) } @@ -254,14 +255,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, true) + u, v = indirect(v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, false) + u, v = indirect(v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } @@ -386,7 +387,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -513,7 +514,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -564,32 +565,48 @@ func (d *Decoder) decodeList(avList []types.AttributeValue, v reflect.Value) err return nil } -func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) error { +func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) { + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + switch v.Kind() { case reflect.Map: - t := v.Type() - if t.Key().Kind() != reflect.String { - return &UnmarshalTypeError{Value: "map string key", Type: t.Key()} + decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key()) + if err != nil { + return err } + if v.IsNil() { - v.Set(reflect.MakeMap(t)) + v.Set(reflect.MakeMap(v.Type())) } case reflect.Struct: case reflect.Interface: v.Set(reflect.MakeMap(stringInterfaceMapType)) + decodeMapKey = d.decodeString v = v.Elem() default: return &UnmarshalTypeError{Value: "map", Type: v.Type()} } if v.Kind() == reflect.Map { + keyType := v.Type().Key() + valueType := v.Type().Elem() for k, av := range avMap { - key := reflect.New(v.Type().Key()).Elem() - key.SetString(k) - elem := reflect.New(v.Type().Elem()).Elem() + key := reflect.New(keyType).Elem() + // handle pointer keys + _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + if err := decodeMapKey(k, indirectKey, tag{}); err != nil { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("map key %q", k), + Type: keyType, + Err: err, + } + } + + elem := reflect.New(valueType).Elem() if err := d.decode(av, elem, tag{}); err != nil { return err } + v.SetMapIndex(key, elem) } } else if v.Kind() == reflect.Struct { @@ -609,6 +626,50 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val return nil } +var numberType = reflect.TypeOf(Number("")) +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func (d *Decoder) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, tag) error, error) { + // Test the key type to determine if it implements the TextUnmarshaler interface. + if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) { + return func(v string, k reflect.Value, _ tag) error { + if !k.CanAddr() { + return fmt.Errorf("cannot take address of map key, %v", k.Type()) + } + return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v)) + }, nil + } + + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + + switch keyType.Kind() { + case reflect.Bool: + decodeMapKey = func(v string, key reflect.Value, fieldTag tag) error { + b, err := strconv.ParseBool(v) + if err != nil { + return err + } + return d.decodeBool(b, key) + } + case reflect.String: + // Number type handled as a string + decodeMapKey = d.decodeString + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + decodeMapKey = d.decodeNumber + + default: + return nil, &UnmarshalTypeError{ + Value: "map key must be string, number, bool, or TextUnmarshaler", + Type: keyType, + } + } + + return decodeMapKey, nil +} + func (d *Decoder) decodeNull(v reflect.Value) error { if v.IsValid() && v.CanSet() { v.Set(reflect.Zero(v.Type())) @@ -675,7 +736,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -713,38 +774,82 @@ func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value { return v } +type indirectOptions struct { + decodeNull bool + skipUnmarshaler bool +} + // indirect will walk a value's interface or pointer value types. Returning // the final value or the value a unmarshaler is defined on. // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { +func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true v = v.Addr() } + for { + // Load value from interface, but only if the result will be + // usefully addressable. if v.Kind() == reflect.Interface && !v.IsNil() { e := v.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false v = e continue } + if e.Kind() != reflect.Ptr && e.IsValid() { + return nil, e + } } if v.Kind() != reflect.Ptr { break } - if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + if opts.decodeNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() break } if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - if v.Type().NumMethod() > 0 { + if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { return u, reflect.Value{} } } - v = v.Elem() + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } } return nil, v @@ -782,8 +887,12 @@ func (n Number) String() string { type UnmarshalTypeError struct { Value string Type reflect.Type + Err error } +// Unwrap returns the underlying error if any. +func (e *UnmarshalTypeError) Unwrap() error { return e.Err } + // Error returns the string representation of the error. // satisfying the error interface func (e *UnmarshalTypeError) Error() string { diff --git a/feature/dynamodb/attributevalue/decode_test.go b/feature/dynamodb/attributevalue/decode_test.go index feafd697c83..10e279de89e 100644 --- a/feature/dynamodb/attributevalue/decode_test.go +++ b/feature/dynamodb/attributevalue/decode_test.go @@ -335,7 +335,10 @@ func TestUnmarshalMapError(t *testing.T) { }, actual: &map[int]interface{}{}, expected: nil, - err: &UnmarshalTypeError{Value: "map string key", Type: reflect.TypeOf(int(0))}, + err: &UnmarshalTypeError{ + Value: `map key "BOOL"`, + Type: reflect.TypeOf(int(0)), + }, }, } @@ -765,3 +768,197 @@ func TestDecodeAliasType(t *testing.T) { t.Errorf("expect:\n%v\nactual:\n%v", expect, actual) } } + +type testUnmarshalMapKeyComplex struct { + Foo string +} + +func (t *testUnmarshalMapKeyComplex) UnmarshalText(b []byte) error { + t.Foo = string(b) + return nil +} +func (t *testUnmarshalMapKeyComplex) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + avM, ok := av.(*types.AttributeValueMemberM) + if !ok { + return fmt.Errorf("unexpected AttributeValue type %T, %v", av, av) + } + avFoo, ok := avM.Value["foo"] + if !ok { + return nil + } + + avS, ok := avFoo.(*types.AttributeValueMemberS) + if !ok { + return fmt.Errorf("unexpected Foo AttributeValue type, %T, %v", avM, avM) + } + + t.Foo = avS.Value + + return nil +} + +func TestUnmarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input map[string]types.AttributeValue + expectVal interface{} + expectType func() interface{} + }{ + "string key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[string]interface{}{} }, + expectVal: map[string]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "string alias key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[StrAlias]interface{}{} }, + expectVal: map[StrAlias]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "Number key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[Number]interface{}{} }, + expectVal: map[Number]interface{}{ + Number("1"): 123., + Number("2"): "efg", + }, + }, + "int key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[int]interface{}{} }, + expectVal: map[int]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "int alias key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[IntAlias]interface{}{} }, + expectVal: map[IntAlias]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "bool key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[bool]interface{}{} }, + expectVal: map[bool]interface{}{ + true: 123., + false: "efg", + }, + }, + "bool alias key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[BoolAlias]interface{}{} }, + expectVal: map[BoolAlias]interface{}{ + true: 123., + false: "efg", + }, + }, + "textMarshaler key": { + input: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testTextMarshaler]interface{}{} }, + expectVal: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + "textMarshaler DDBAvMarshaler key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testUnmarshalMapKeyComplex]interface{}{} }, + expectVal: map[testUnmarshalMapKeyComplex]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualVal := c.expectType() + err := UnmarshalMap(c.input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if diff := cmp.Diff(c.expectVal, actualVal); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + }) + } +} + +func TestUnmarshalMap_keyPtrTypes(t *testing.T) { + input := map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + } + + expectVal := map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + } + + actualVal := map[*testTextMarshaler]interface{}{} + err := UnmarshalMap(input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if e, a := len(expectVal), len(actualVal); e != a { + t.Errorf("expect %v values, got %v", e, a) + } + + for k, v := range expectVal { + var found bool + for ak, av := range actualVal { + if *k == *ak { + found = true + if diff := cmp.Diff(v, av); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + } + } + if !found { + t.Errorf("expect %v key not found", *k) + } + } + +} diff --git a/feature/dynamodb/attributevalue/encode.go b/feature/dynamodb/attributevalue/encode.go index c8dcf94736a..4a96c08d31c 100644 --- a/feature/dynamodb/attributevalue/encode.go +++ b/feature/dynamodb/attributevalue/encode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -380,6 +381,7 @@ type Encoder struct { // the `opts` functional options to override the default configuration. func NewEncoder(optFns ...func(*EncoderOptions)) *Encoder { options := EncoderOptions{ + TagKey: defaultTagKey, NullEmptySets: true, } for _, fn := range optFns { @@ -497,9 +499,9 @@ func (e *Encoder) encodeStruct(v reflect.Value, fieldTag tag) (types.AttributeVa func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { m := &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}} for _, key := range v.MapKeys() { - keyName := fmt.Sprint(key.Interface()) - if keyName == "" { - return nil, &InvalidMarshalError{msg: "map key cannot be empty"} + keyName, err := mapKeyAsString(key, fieldTag) + if err != nil { + return nil, err } elemVal := v.MapIndex(key) @@ -519,6 +521,40 @@ func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue return m, nil } +func mapKeyAsString(keyVal reflect.Value, fieldTag tag) (keyStr string, err error) { + defer func() { + if err != nil { + return + } + if keyStr == "" { + err = &InvalidMarshalError{msg: "map key cannot be empty"} + } + }() + + if k, ok := keyVal.Interface().(encoding.TextMarshaler); ok { + b, err := k.MarshalText() + if err != nil { + return "", fmt.Errorf("failed to marshal text, %w", err) + } + return string(b), err + } + + switch keyVal.Kind() { + case reflect.Bool, + reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + + return fmt.Sprint(keyVal.Interface()), nil + + default: + return "", &InvalidMarshalError{ + msg: "map key type not supported, must be string, number, bool, or TextMarshaler", + } + } +} + func (e *Encoder) encodeSlice(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { if v.Type().Elem().Kind() == reflect.Uint8 { slice := reflect.MakeSlice(byteSliceType, v.Len(), v.Len()) diff --git a/feature/dynamodb/attributevalue/encode_test.go b/feature/dynamodb/attributevalue/encode_test.go index 58793aee34f..23c2fbfc01a 100644 --- a/feature/dynamodb/attributevalue/encode_test.go +++ b/feature/dynamodb/attributevalue/encode_test.go @@ -1,13 +1,14 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" "reflect" "strconv" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/go-cmp/cmp" @@ -366,3 +367,130 @@ func TestEncoderFieldByIndex(t *testing.T) { t.Error("expected f to be of kind Int with value equal to outer.Inner") } } + +func TestMarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input interface{} + expectAV map[string]types.AttributeValue + }{ + "string key": { + input: map[string]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "string alias key": { + input: map[StrAlias]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "Number key": { + input: map[Number]interface{}{ + Number("1"): 123, + Number("2"): "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int key": { + input: map[int]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int alias key": { + input: map[IntAlias]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool key": { + input: map[bool]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool alias key": { + input: map[BoolAlias]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler key": { + input: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler ptr key": { + input: map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + av, err := MarshalMap(c.input) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + cmpOptions := cmp.Options{ + cmpopts.IgnoreUnexported(types.AttributeValueMemberM{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberN{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBOOL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberB{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberSS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNULL{}), + } + if diff := cmp.Diff(c.expectAV, av, cmpOptions...); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} diff --git a/feature/dynamodb/attributevalue/field.go b/feature/dynamodb/attributevalue/field.go index 7abd3479a96..4f63bc7df99 100644 --- a/feature/dynamodb/attributevalue/field.go +++ b/feature/dynamodb/attributevalue/field.go @@ -5,6 +5,8 @@ import ( "sort" ) +const defaultTagKey = "dynamodbav" + type field struct { tag @@ -46,7 +48,12 @@ type structFieldOptions struct { // unionStructFields returns a list of fields for the given type. Type info is cached // to avoid repeated calls into the reflect package func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { - if cached, ok := fieldCache.Load(t); ok { + key := fieldCacheKey{ + typ: t, + opts: opts, + } + + if cached, ok := fieldCache.Load(key); ok { return cached } @@ -62,7 +69,7 @@ func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { fs.fieldsByName[f.Name] = i } - cached, _ := fieldCache.LoadOrStore(t, fs) + cached, _ := fieldCache.LoadOrStore(key, fs) return cached } @@ -105,7 +112,7 @@ func enumFields(t reflect.Type, opts structFieldOptions) []field { fieldTag := tag{} fieldTag.parseAVTag(sf.Tag) // Because MarshalOptions.TagKey must be explicitly set. - if opts.TagKey != "" && fieldTag == (tag{}) { + if opts.TagKey != "" && opts.TagKey != defaultTagKey { fieldTag.parseStructTag(opts.TagKey, sf.Tag) } diff --git a/feature/dynamodb/attributevalue/field_cache.go b/feature/dynamodb/attributevalue/field_cache.go index 60a9d9c7499..c0fc4679a86 100644 --- a/feature/dynamodb/attributevalue/field_cache.go +++ b/feature/dynamodb/attributevalue/field_cache.go @@ -1,25 +1,31 @@ package attributevalue import ( + "reflect" "strings" "sync" ) -var fieldCache fieldCacher +var fieldCache = &fieldCacher{} + +type fieldCacheKey struct { + typ reflect.Type + opts structFieldOptions +} type fieldCacher struct { cache sync.Map } -func (c *fieldCacher) Load(t interface{}) (*cachedFields, bool) { - if v, ok := c.cache.Load(t); ok { +func (c *fieldCacher) Load(key fieldCacheKey) (*cachedFields, bool) { + if v, ok := c.cache.Load(key); ok { return v.(*cachedFields), true } return nil, false } -func (c *fieldCacher) LoadOrStore(t interface{}, fs *cachedFields) (*cachedFields, bool) { - v, ok := c.cache.LoadOrStore(t, fs) +func (c *fieldCacher) LoadOrStore(key fieldCacheKey, fs *cachedFields) (*cachedFields, bool) { + v, ok := c.cache.LoadOrStore(key, fs) return v.(*cachedFields), ok } diff --git a/feature/dynamodb/attributevalue/field_test.go b/feature/dynamodb/attributevalue/field_test.go index 82a09d6cf99..9c2b291703d 100644 --- a/feature/dynamodb/attributevalue/field_test.go +++ b/feature/dynamodb/attributevalue/field_test.go @@ -1,6 +1,7 @@ package attributevalue import ( + "fmt" "reflect" "testing" ) @@ -22,7 +23,7 @@ type unionComplex struct { } type unionTagged struct { - A int `json:"A"` + A int `dynamodbav:"ddbav" json:"A" taga:"TagA" tagb:"TagB"` } type unionTaggedComplex struct { @@ -32,97 +33,211 @@ type unionTaggedComplex struct { } func TestUnionStructFields(t *testing.T) { - var cases = []struct { + origFieldCache := fieldCache + defer func() { fieldCache = origFieldCache }() + + fieldCache = &fieldCacher{} + + var cases = map[string]struct { in interface{} + opts structFieldOptions expect []testUnionValues }{ - { - in: unionSimple{1, "2", []string{"abc"}}, + "simple input": { + in: unionSimple{1, "2", []string{"abc"}}, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"A", 1}, {"B", "2"}, {"C", []string{"abc"}}, }, }, - { + "nested struct": { in: unionComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, A: 2, }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"B", "2"}, {"C", []string{"abc"}}, {"A", 2}, }, }, - { + "with TagKey unset": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"ddbav", 3}, + {"B", "3"}, + }, + }, + "with TagKey json": { in: unionTaggedComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, unionTagged: unionTagged{3}, B: "3", }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"C", []string{"abc"}}, {"A", 3}, {"B", "3"}, }, }, + "with TagKey taga": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "taga"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagA", 3}, + {"B", "3"}, + }, + }, + "with TagKey tagb": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "tagb"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagB", 3}, + {"B", "3"}, + }, + }, } - for i, c := range cases { - v := reflect.ValueOf(c.in) + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v := reflect.ValueOf(c.in) - fields := unionStructFields(v.Type(), structFieldOptions{TagKey: "json"}) - for j, f := range fields.All() { - expected := c.expect[j] - if e, a := expected.Name, f.Name; e != a { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) - } - actual := v.FieldByIndex(f.Index).Interface() - if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) + fields := unionStructFields(v.Type(), c.opts) + for i, f := range fields.All() { + expected := c.expect[i] + if e, a := expected.Name, f.Name; e != a { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } + actual := v.FieldByIndex(f.Index).Interface() + if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } } - } + }) } } func TestCachedFields(t *testing.T) { type myStruct struct { - Dog int + Dog int `tag1:"rabbit" tag2:"cow" tag3:"horse"` CAT string bird bool } - fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{}) - - const expectedNumFields = 2 - if numFields := len(fields.All()); numFields != expectedNumFields { - t.Errorf("expected number of fields to be %d but got %d", expectedNumFields, numFields) - } - - cases := []struct { + cases := map[string][]struct { Name string FieldName string Found bool }{ - {"Dog", "Dog", true}, - {"dog", "Dog", true}, - {"DOG", "Dog", true}, - {"Yorkie", "", false}, - {"Cat", "CAT", true}, - {"cat", "CAT", true}, - {"CAT", "CAT", true}, - {"tiger", "", false}, - {"bird", "", false}, + "": { + {"Dog", "Dog", true}, + {"dog", "Dog", true}, + {"DOG", "Dog", true}, + {"Yorkie", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag1": { + {"rabbit", "rabbit", true}, + {"Rabbit", "rabbit", true}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag2": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "cow", true}, + {"Cow", "cow", true}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag3": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "horse", true}, + {"Horse", "horse", true}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, } - for _, c := range cases { - f, found := fields.FieldByName(c.Name) - if found != c.Found { - t.Errorf("expected found to be %v but got %v", c.Found, found) - } - if found && f.Name != c.FieldName { - t.Errorf("expected field name to be %s but got %s", c.FieldName, f.Name) + for tagKey, cs := range cases { + for _, c := range cs { + name := tagKey + if name == "" { + name = "none" + } + t.Run(fmt.Sprintf("%s/%s", name, c.Name), func(t *testing.T) { + t.Parallel() + + fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{ + TagKey: tagKey, + }) + + const expectedNumFields = 2 + if numFields := len(fields.All()); numFields != expectedNumFields { + t.Errorf("expect %v fields, got %d", expectedNumFields, numFields) + } + + f, found := fields.FieldByName(c.Name) + if found != c.Found { + t.Errorf("expect %v found, got %v", c.Found, found) + } + if found && f.Name != c.FieldName { + t.Errorf("expect %v field name, got %s", c.FieldName, f.Name) + } + }) } } } diff --git a/feature/dynamodb/attributevalue/marshaler_test.go b/feature/dynamodb/attributevalue/marshaler_test.go index 6d84c8adac5..6d895e89b4a 100644 --- a/feature/dynamodb/attributevalue/marshaler_test.go +++ b/feature/dynamodb/attributevalue/marshaler_test.go @@ -520,7 +520,7 @@ func compareObjects(t *testing.T, expected interface{}, actual interface{}) { } func BenchmarkMarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -547,7 +547,7 @@ func BenchmarkMarshalOneMember(b *testing.B) { } func BenchmarkMarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -576,7 +576,7 @@ func BenchmarkMarshalTwoMembers(b *testing.B) { } func BenchmarkUnmarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", @@ -605,7 +605,7 @@ func BenchmarkUnmarshalOneMember(b *testing.B) { } func BenchmarkUnmarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", diff --git a/feature/dynamodb/attributevalue/shared_test.go b/feature/dynamodb/attributevalue/shared_test.go index 63a09bbf640..b2c249ae4ef 100644 --- a/feature/dynamodb/attributevalue/shared_test.go +++ b/feature/dynamodb/attributevalue/shared_test.go @@ -1,18 +1,37 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" + "fmt" "reflect" "strings" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/go-cmp/cmp" ) +type testTextMarshaler struct { + Foo string +} + +func (t *testTextMarshaler) UnmarshalText(b []byte) error { + if !strings.HasPrefix(string(b), "Foo:") { + return fmt.Errorf(`missing "Foo:" prefix`) + } + + t.Foo = string(b)[len("Foo:"):] + return nil +} + +func (t testTextMarshaler) MarshalText() ([]byte, error) { + return []byte("Foo:" + t.Foo), nil +} + type testBinarySetStruct struct { Binarys [][]byte `dynamodbav:",binaryset"` } diff --git a/feature/dynamodb/attributevalue/tag.go b/feature/dynamodb/attributevalue/tag.go index 6eb901706fb..f01c432e6ea 100644 --- a/feature/dynamodb/attributevalue/tag.go +++ b/feature/dynamodb/attributevalue/tag.go @@ -18,7 +18,7 @@ type tag struct { } func (t *tag) parseAVTag(structTag reflect.StructTag) { - tagStr := structTag.Get("dynamodbav") + tagStr := structTag.Get(defaultTagKey) if len(tagStr) == 0 { return } diff --git a/feature/dynamodb/expression/expression_test.go b/feature/dynamodb/expression/expression_test.go index e748277d281..c3c3565da71 100644 --- a/feature/dynamodb/expression/expression_test.go +++ b/feature/dynamodb/expression/expression_test.go @@ -384,7 +384,7 @@ func TestUpdate(t *testing.T) { setOperation: { { name: NameBuilder{ - name: "foo", + names: []string{"foo"}, }, value: ValueBuilder{ value: 5, @@ -407,7 +407,7 @@ func TestUpdate(t *testing.T) { setOperation: { { name: NameBuilder{ - name: "foo", + names: []string{"foo"}, }, value: ValueBuilder{ value: 5, @@ -416,7 +416,7 @@ func TestUpdate(t *testing.T) { }, { name: NameBuilder{ - name: "bar", + names: []string{"bar"}, }, value: ValueBuilder{ value: 6, @@ -425,7 +425,7 @@ func TestUpdate(t *testing.T) { }, { name: NameBuilder{ - name: "baz", + names: []string{"baz"}, }, value: ValueBuilder{ value: 7, @@ -496,7 +496,7 @@ func TestNames(t *testing.T) { condition: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "foo", + names: []string{"foo"}, }, ValueBuilder{ value: 5, @@ -507,7 +507,7 @@ func TestNames(t *testing.T) { filter: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "bar", + names: []string{"bar"}, }, ValueBuilder{ value: 6, @@ -518,13 +518,13 @@ func TestNames(t *testing.T) { projection: ProjectionBuilder{ names: []NameBuilder{ { - name: "foo", + names: []string{"foo"}, }, { - name: "bar", + names: []string{"bar"}, }, { - name: "baz", + names: []string{"baz"}, }, }, }, @@ -618,7 +618,7 @@ func TestValues(t *testing.T) { condition: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "foo", + names: []string{"foo"}, }, ValueBuilder{ value: 5, @@ -629,7 +629,7 @@ func TestValues(t *testing.T) { filter: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "bar", + names: []string{"bar"}, }, ValueBuilder{ value: 6, @@ -640,13 +640,13 @@ func TestValues(t *testing.T) { projection: ProjectionBuilder{ names: []NameBuilder{ { - name: "foo", + names: []string{"foo"}, }, { - name: "bar", + names: []string{"bar"}, }, { - name: "baz", + names: []string{"baz"}, }, }, }, @@ -702,7 +702,7 @@ func TestBuildChildTrees(t *testing.T) { condition: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "foo", + names: []string{"foo"}, }, ValueBuilder{ value: 5, @@ -713,7 +713,7 @@ func TestBuildChildTrees(t *testing.T) { filter: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "bar", + names: []string{"bar"}, }, ValueBuilder{ value: 6, @@ -724,13 +724,13 @@ func TestBuildChildTrees(t *testing.T) { projection: ProjectionBuilder{ names: []NameBuilder{ { - name: "foo", + names: []string{"foo"}, }, { - name: "bar", + names: []string{"bar"}, }, { - name: "baz", + names: []string{"baz"}, }, }, }, diff --git a/feature/dynamodb/expression/go.mod b/feature/dynamodb/expression/go.mod index 0d6eb5bbf0c..6f1311f5580 100644 --- a/feature/dynamodb/expression/go.mod +++ b/feature/dynamodb/expression/go.mod @@ -6,6 +6,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.13.0 github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.6.0 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.13.0 + github.com/google/go-cmp v0.5.6 ) replace github.com/aws/aws-sdk-go-v2 => ../../../ diff --git a/feature/dynamodb/expression/operand.go b/feature/dynamodb/expression/operand.go index 77caed0615e..7eabc00a663 100644 --- a/feature/dynamodb/expression/operand.go +++ b/feature/dynamodb/expression/operand.go @@ -18,7 +18,16 @@ import ( // // Create a ValueBuilder representing the string "aValue" // valueBuilder := expression.Value("aValue") type ValueBuilder struct { - value interface{} + value interface{} + options ValueBuilderOptions +} + +// ValueBuilderOptions provides the options for how a value is built, and +// encoded in the expression. +type ValueBuilderOptions struct { + // Use functional options to specify how the value will be encoded. If the + // value is already an AttributeValue, the EncoderOptions will be ignored. + EncoderOptions []func(*attributevalue.EncoderOptions) } // NameBuilder represents a name of a top level item attribute or a nested @@ -32,7 +41,7 @@ type ValueBuilder struct { // // Create a NameBuilder representing the item attribute "aName" // nameBuilder := expression.Name("aName") type NameBuilder struct { - name string + names []string } // SizeBuilder represents the output of the function size ("someName"), which @@ -132,8 +141,75 @@ type OperandBuilder interface { // // Use Name() to create a condition expression // condition := expression.Name("foo").Equal(expression.Name("bar")) func Name(name string) NameBuilder { + if len(name) == 0 { + return NameBuilder{} + } + + return NameBuilder{ + names: strings.Split(name, "."), + } +} + +// AppendName to adds additional name fields, returning a new NameBuilder. Can +// be used to append list indexes and map fields to the Expression attribute +// name. +// +// Leading or trailing dots(`.`) for Names that are not created with +// NameNoDotSplit will result in an error when the expression is built. The +// dot(`.`) will be added automatically as needed. +func (nb NameBuilder) AppendName(field NameBuilder) NameBuilder { + names := make([]string, 0, len(nb.names)+len(field.names)) + names = append(names, nb.names...) + names = append(names, field.names...) + + // If the name being append starts with a list index it to the name being + // appended to. This allows list indexes to be appended to names. If there + // is a syntax error in the name, it will be caught when the expression is + // built via BuildOperand method. + if len(nb.names) != 0 && len(field.names) != 0 { + lastLeftName := len(nb.names) - 1 + firstRightName := lastLeftName + 1 + if v := names[firstRightName]; len(v) > 0 && v[0] == '[' { + if end := strings.Index(v, "]"); end != -1 { + names[lastLeftName] += v[0 : end+1] + names[firstRightName] = v[end+1:] + // Remove the name if it is empty after moving the index. + if len(names[firstRightName]) == 0 { + copy(names[firstRightName:], names[firstRightName+1:]) + names[len(names)-1] = "" + names = names[:len(names)-1] + } + } + } + } + + return NameBuilder{ + names: names, + } +} + +// NameNoDotSplit returns a NameBuilder. The argument should represent the +// desired item attribute. The name will not be split on dots. The name may end +// with square brackets for list indexes. Square brackets will not be +// considered a part of the NameLiteral. +// +// Use NameBuilder.WithField method to add subsequent map field names. +// Use NameBuilder.WithListIndex method to add list index to the name. +// +// See: http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.Attributes.html +// +// Example: +// +// // Specify a name containing dots, and should not be split. +// name := expression.NameLiteral("Top.Level") +// +// // Specify a nested attribute +// nested := expression.Name("Record[6].SongList") +// // Use Name() to create a condition expression +// condition := expression.Name("foo").Equal(expression.Name("bar")) +func NameNoDotSplit(name string) NameBuilder { return NameBuilder{ - name: name, + names: []string{name}, } } @@ -157,6 +233,37 @@ func Value(value interface{}) ValueBuilder { } } +// ValueWithOptions creates a ValueBuilder and sets its value to the argument. The value +// will be marshalled using the attributevalue package, unless it is of +// type types.AttributeValue, where it will be used directly. +// +// The ValueBuilderOptions functional options parameter allows you to specify +// how the value will be encoded. Including options like AttributeValue +// encoding struct tag. If value is already an DynamoDB AttributeValue, +// encoding options will have not effect. +// +// Empty slices and maps will be encoded as their respective empty types.AttributeValue +// types. If a NULL value is required, pass a dynamodb.AttributeValue, e.g.: +// emptyList := &types.AttributeValueMemberNULL{Value: true} +// +// Example: +// +// // Use Value() to create a condition expression +// condition := expression.Name("foo").Equal(expression.Value(10)) +// // Use Value() to set the value of a set expression. +// update := Set(expression.Name("greets"), expression.Value(&types.AttributeValueMemberS{Value: "hello"})) +func ValueWithOptions(value interface{}, optFns ...func(*ValueBuilderOptions)) ValueBuilder { + var options ValueBuilderOptions + for _, fn := range optFns { + fn(&options) + } + + return ValueBuilder{ + value: value, + options: options, + } +} + // Size creates a SizeBuilder representing the size of the item attribute // specified by the argument NameBuilder. Size() is only valid for certain types // of item attributes. For documentation, @@ -455,7 +562,10 @@ func IfNotExists(name NameBuilder, setValue OperandBuilder) SetValueBuilder { // // // Use IfNotExists() to set item attribute "someName" to value 5 if // // "someName" does not exist yet. (Prevents overwrite) -// update, err := expression.Set(expression.Name("someName"), expression.Name("someName").IfNotExists(expression.Value(5))) +// update, err := expression.Set( +// expression.Name("someName"), +// expression.Name("someName").IfNotExists(expression.Value(5)), +// ) // // Expression Equivalent: // @@ -473,9 +583,10 @@ func (nb NameBuilder) IfNotExists(rightOperand OperandBuilder) SetValueBuilder { // Builder is called. BuildOperand() should never be called externally. // BuildOperand() aliases all strings to avoid stepping over DynamoDB's reserved // words. +// // More information on reserved words at http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ReservedWords.html func (nb NameBuilder) BuildOperand() (Operand, error) { - if nb.name == "" { + if len(nb.names) == 0 { return Operand{}, newUnsetParameterError("BuildOperand", "NameBuilder") } @@ -483,15 +594,17 @@ func (nb NameBuilder) BuildOperand() (Operand, error) { names: []string{}, } - nameSplit := strings.Split(nb.name, ".") - fmtNames := make([]string, 0, len(nameSplit)) - - for _, word := range nameSplit { + fmtNames := make([]string, 0, len(nb.names)) + for _, word := range nb.names { var substr string if word == "" { return Operand{}, newInvalidParameterError("BuildOperand", "NameBuilder") } + if idx := strings.Index(word, "]"); idx != -1 && idx != len(word)-1 { + return Operand{}, newInvalidParameterError("BuildOperand", "NameBuilder") + } + if word[len(word)-1] == ']' { for j, char := range word { if char == '[' { @@ -554,7 +667,7 @@ func (vb ValueBuilder) BuildOperand() (Operand, error) { case types.AttributeValue: expr = v default: - expr, err = attributevalue.Marshal(vb.value) + expr, err = attributevalue.MarshalWithOptions(vb.value, vb.options.EncoderOptions...) if err != nil { return Operand{}, newInvalidParameterError("BuildOperand", "ValueBuilder") } diff --git a/feature/dynamodb/expression/operand_test.go b/feature/dynamodb/expression/operand_test.go index 5921b0dfec0..108d49916ce 100644 --- a/feature/dynamodb/expression/operand_test.go +++ b/feature/dynamodb/expression/operand_test.go @@ -1,11 +1,13 @@ package expression import ( - "reflect" "strings" "testing" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) // opeErrorMode will help with error cases and checking error types @@ -23,6 +25,11 @@ const ( ) func TestBuildOperand(t *testing.T) { + type mockStructValue struct { + A string `dynamodbav:"ddbA" tagb:"TagB"` + B string + } + cases := []struct { name string input OperandBuilder @@ -45,6 +52,37 @@ func TestBuildOperand(t *testing.T) { fmtExpr: "$n.$n", }, }, + { + name: "struct value", + input: ValueWithOptions(mockStructValue{A: "abc123", B: "efg456"}), + expected: exprNode{ + values: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "ddbA": &types.AttributeValueMemberS{Value: "abc123"}, + "B": &types.AttributeValueMemberS{Value: "efg456"}, + }}, + }, + fmtExpr: "$v", + }, + }, + { + name: "struct value with TagKey", + input: ValueWithOptions(mockStructValue{A: "abc123", B: "efg456"}, + func(o *ValueBuilderOptions) { + o.EncoderOptions = append(o.EncoderOptions, func(o *attributevalue.EncoderOptions) { + o.TagKey = "tagb" + }) + }), + expected: exprNode{ + values: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "TagB": &types.AttributeValueMemberS{Value: "abc123"}, + "B": &types.AttributeValueMemberS{Value: "efg456"}, + }}, + }, + fmtExpr: "$v", + }, + }, { name: "basic value", input: Value(5), @@ -181,6 +219,14 @@ func TestBuildOperand(t *testing.T) { fmtExpr: "$n.$n", }, }, + { + name: "no split name", + input: NameNoDotSplit("foo.bar"), + expected: exprNode{ + names: []string{"foo.bar"}, + fmtExpr: "$n", + }, + }, { name: "nested name with index", input: Name("foo.bar[0].baz"), @@ -189,6 +235,33 @@ func TestBuildOperand(t *testing.T) { fmtExpr: "$n.$n[0].$n", }, }, + { + name: "no split name with index", + input: NameNoDotSplit("foo.bar[0]"), + expected: exprNode{ + names: []string{"foo.bar"}, + fmtExpr: "$n[0]", + }, + }, + { + name: "no split name append name", + input: NameNoDotSplit("foo.bar").AppendName(Name("foo.bar")), + expected: exprNode{ + names: []string{"foo.bar", "foo", "bar"}, + fmtExpr: "$n.$n.$n", + }, + }, + { + name: "no split name append name with list index", + input: NameNoDotSplit("foo.bar"). + AppendName(Name("foo.bar")). + AppendName(Name("[0]")). + AppendName(Name("abc123")), + expected: exprNode{ + names: []string{"foo.bar", "foo", "bar", "abc123"}, + fmtExpr: "$n.$n.$n[0].$n", + }, + }, { name: "basic size", input: Name("foo").Size(), @@ -238,19 +311,30 @@ func TestBuildOperand(t *testing.T) { if c.err != noOperandError { if err == nil { t.Errorf("expect error %q, got no error", c.err) - } else { - if e, a := string(c.err), err.Error(); !strings.Contains(a, e) { - t.Errorf("expect %q error message to be in %q", e, a) - } - } - } else { - if err != nil { - t.Errorf("expect no error, got unexpected Error %q", err) + } else if e, a := string(c.err), err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %q error message to be in %q", e, a) } + return + } + if err != nil { + t.Fatalf("expect no error, got unexpected Error %q", err) + } - if e, a := c.expected, operand.exprNode; !reflect.DeepEqual(a, e) { - t.Errorf("expect %v, got %v", e, a) - } + cmpOptions := cmp.Options{ + cmp.AllowUnexported(exprNode{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberM{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberN{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBOOL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberB{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberSS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNULL{}), + } + if diff := cmp.Diff(c.expected, operand.exprNode, cmpOptions...); diff != "" { + t.Errorf("expect operand match\n%s", diff) } }) } diff --git a/feature/dynamodb/expression/update_test.go b/feature/dynamodb/expression/update_test.go index 917e63d2276..4c0a4108ce6 100644 --- a/feature/dynamodb/expression/update_test.go +++ b/feature/dynamodb/expression/update_test.go @@ -599,7 +599,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "foo", + names: []string{"foo"}, }, value: ValueBuilder{ value: 5, @@ -608,7 +608,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "bar", + names: []string{"bar"}, }, value: ValueBuilder{ value: 6, @@ -617,7 +617,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "baz", + names: []string{"baz"}, }, value: ValueBuilder{ value: 7, @@ -626,7 +626,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "qux", + names: []string{"qux"}, }, value: ValueBuilder{ value: 8, diff --git a/feature/dynamodbstreams/attributevalue/decode.go b/feature/dynamodbstreams/attributevalue/decode.go index 368c51936ec..2d6ec05e2d7 100644 --- a/feature/dynamodbstreams/attributevalue/decode.go +++ b/feature/dynamodbstreams/attributevalue/decode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -197,7 +198,7 @@ func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out int } // DecoderOptions is a collection of options to configure how the decoder -// unmarshalls the value. +// unmarshals the value. type DecoderOptions struct { // Support other custom struct tag keys, such as `yaml`, `json`, or `toml`. // Note that values provided with a custom TagKey must also be supported @@ -221,7 +222,7 @@ type Decoder struct { // NewDecoder creates a new Decoder with default configuration. Use // the `opts` functional options to override the default configuration. func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder { - var options DecoderOptions + options := DecoderOptions{TagKey: defaultTagKey} for _, fn := range optFns { fn(&options) } @@ -254,14 +255,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, true) + u, v = indirect(v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, false) + u, v = indirect(v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(av) } @@ -386,7 +387,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -513,7 +514,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -564,32 +565,48 @@ func (d *Decoder) decodeList(avList []types.AttributeValue, v reflect.Value) err return nil } -func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) error { +func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) { + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + switch v.Kind() { case reflect.Map: - t := v.Type() - if t.Key().Kind() != reflect.String { - return &UnmarshalTypeError{Value: "map string key", Type: t.Key()} + decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key()) + if err != nil { + return err } + if v.IsNil() { - v.Set(reflect.MakeMap(t)) + v.Set(reflect.MakeMap(v.Type())) } case reflect.Struct: case reflect.Interface: v.Set(reflect.MakeMap(stringInterfaceMapType)) + decodeMapKey = d.decodeString v = v.Elem() default: return &UnmarshalTypeError{Value: "map", Type: v.Type()} } if v.Kind() == reflect.Map { + keyType := v.Type().Key() + valueType := v.Type().Elem() for k, av := range avMap { - key := reflect.New(v.Type().Key()).Elem() - key.SetString(k) - elem := reflect.New(v.Type().Elem()).Elem() + key := reflect.New(keyType).Elem() + // handle pointer keys + _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + if err := decodeMapKey(k, indirectKey, tag{}); err != nil { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("map key %q", k), + Type: keyType, + Err: err, + } + } + + elem := reflect.New(valueType).Elem() if err := d.decode(av, elem, tag{}); err != nil { return err } + v.SetMapIndex(key, elem) } } else if v.Kind() == reflect.Struct { @@ -609,6 +626,50 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val return nil } +var numberType = reflect.TypeOf(Number("")) +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func (d *Decoder) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, tag) error, error) { + // Test the key type to determine if it implements the TextUnmarshaler interface. + if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) { + return func(v string, k reflect.Value, _ tag) error { + if !k.CanAddr() { + return fmt.Errorf("cannot take address of map key, %v", k.Type()) + } + return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v)) + }, nil + } + + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + + switch keyType.Kind() { + case reflect.Bool: + decodeMapKey = func(v string, key reflect.Value, fieldTag tag) error { + b, err := strconv.ParseBool(v) + if err != nil { + return err + } + return d.decodeBool(b, key) + } + case reflect.String: + // Number type handled as a string + decodeMapKey = d.decodeString + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + decodeMapKey = d.decodeNumber + + default: + return nil, &UnmarshalTypeError{ + Value: "map key must be string, number, bool, or TextUnmarshaler", + Type: keyType, + } + } + + return decodeMapKey, nil +} + func (d *Decoder) decodeNull(v reflect.Value) error { if v.IsValid() && v.CanSet() { v.Set(reflect.Zero(v.Type())) @@ -675,7 +736,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -713,38 +774,82 @@ func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value { return v } +type indirectOptions struct { + decodeNull bool + skipUnmarshaler bool +} + // indirect will walk a value's interface or pointer value types. Returning // the final value or the value a unmarshaler is defined on. // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { +func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true v = v.Addr() } + for { + // Load value from interface, but only if the result will be + // usefully addressable. if v.Kind() == reflect.Interface && !v.IsNil() { e := v.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false v = e continue } + if e.Kind() != reflect.Ptr && e.IsValid() { + return nil, e + } } if v.Kind() != reflect.Ptr { break } - if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + if opts.decodeNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() break } if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - if v.Type().NumMethod() > 0 { + if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { return u, reflect.Value{} } } - v = v.Elem() + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } } return nil, v @@ -782,8 +887,12 @@ func (n Number) String() string { type UnmarshalTypeError struct { Value string Type reflect.Type + Err error } +// Unwrap returns the underlying error if any. +func (e *UnmarshalTypeError) Unwrap() error { return e.Err } + // Error returns the string representation of the error. // satisfying the error interface func (e *UnmarshalTypeError) Error() string { diff --git a/feature/dynamodbstreams/attributevalue/decode_test.go b/feature/dynamodbstreams/attributevalue/decode_test.go index e64c36e5e4a..0f83592a558 100644 --- a/feature/dynamodbstreams/attributevalue/decode_test.go +++ b/feature/dynamodbstreams/attributevalue/decode_test.go @@ -335,7 +335,10 @@ func TestUnmarshalMapError(t *testing.T) { }, actual: &map[int]interface{}{}, expected: nil, - err: &UnmarshalTypeError{Value: "map string key", Type: reflect.TypeOf(int(0))}, + err: &UnmarshalTypeError{ + Value: `map key "BOOL"`, + Type: reflect.TypeOf(int(0)), + }, }, } @@ -765,3 +768,197 @@ func TestDecodeAliasType(t *testing.T) { t.Errorf("expect:\n%v\nactual:\n%v", expect, actual) } } + +type testUnmarshalMapKeyComplex struct { + Foo string +} + +func (t *testUnmarshalMapKeyComplex) UnmarshalText(b []byte) error { + t.Foo = string(b) + return nil +} +func (t *testUnmarshalMapKeyComplex) UnmarshalDynamoDBStreamsAttributeValue(av types.AttributeValue) error { + avM, ok := av.(*types.AttributeValueMemberM) + if !ok { + return fmt.Errorf("unexpected AttributeValue type %T, %v", av, av) + } + avFoo, ok := avM.Value["foo"] + if !ok { + return nil + } + + avS, ok := avFoo.(*types.AttributeValueMemberS) + if !ok { + return fmt.Errorf("unexpected Foo AttributeValue type, %T, %v", avM, avM) + } + + t.Foo = avS.Value + + return nil +} + +func TestUnmarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input map[string]types.AttributeValue + expectVal interface{} + expectType func() interface{} + }{ + "string key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[string]interface{}{} }, + expectVal: map[string]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "string alias key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[StrAlias]interface{}{} }, + expectVal: map[StrAlias]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "Number key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[Number]interface{}{} }, + expectVal: map[Number]interface{}{ + Number("1"): 123., + Number("2"): "efg", + }, + }, + "int key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[int]interface{}{} }, + expectVal: map[int]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "int alias key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[IntAlias]interface{}{} }, + expectVal: map[IntAlias]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "bool key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[bool]interface{}{} }, + expectVal: map[bool]interface{}{ + true: 123., + false: "efg", + }, + }, + "bool alias key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[BoolAlias]interface{}{} }, + expectVal: map[BoolAlias]interface{}{ + true: 123., + false: "efg", + }, + }, + "textMarshaler key": { + input: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testTextMarshaler]interface{}{} }, + expectVal: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + "textMarshaler DDBAvMarshaler key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testUnmarshalMapKeyComplex]interface{}{} }, + expectVal: map[testUnmarshalMapKeyComplex]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualVal := c.expectType() + err := UnmarshalMap(c.input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if diff := cmp.Diff(c.expectVal, actualVal); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + }) + } +} + +func TestUnmarshalMap_keyPtrTypes(t *testing.T) { + input := map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + } + + expectVal := map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + } + + actualVal := map[*testTextMarshaler]interface{}{} + err := UnmarshalMap(input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if e, a := len(expectVal), len(actualVal); e != a { + t.Errorf("expect %v values, got %v", e, a) + } + + for k, v := range expectVal { + var found bool + for ak, av := range actualVal { + if *k == *ak { + found = true + if diff := cmp.Diff(v, av); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + } + } + if !found { + t.Errorf("expect %v key not found", *k) + } + } + +} diff --git a/feature/dynamodbstreams/attributevalue/encode.go b/feature/dynamodbstreams/attributevalue/encode.go index 66f649f05d4..89f15c1ee5c 100644 --- a/feature/dynamodbstreams/attributevalue/encode.go +++ b/feature/dynamodbstreams/attributevalue/encode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -380,6 +381,7 @@ type Encoder struct { // the `opts` functional options to override the default configuration. func NewEncoder(optFns ...func(*EncoderOptions)) *Encoder { options := EncoderOptions{ + TagKey: defaultTagKey, NullEmptySets: true, } for _, fn := range optFns { @@ -497,9 +499,9 @@ func (e *Encoder) encodeStruct(v reflect.Value, fieldTag tag) (types.AttributeVa func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { m := &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}} for _, key := range v.MapKeys() { - keyName := fmt.Sprint(key.Interface()) - if keyName == "" { - return nil, &InvalidMarshalError{msg: "map key cannot be empty"} + keyName, err := mapKeyAsString(key, fieldTag) + if err != nil { + return nil, err } elemVal := v.MapIndex(key) @@ -519,6 +521,40 @@ func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue return m, nil } +func mapKeyAsString(keyVal reflect.Value, fieldTag tag) (keyStr string, err error) { + defer func() { + if err != nil { + return + } + if keyStr == "" { + err = &InvalidMarshalError{msg: "map key cannot be empty"} + } + }() + + if k, ok := keyVal.Interface().(encoding.TextMarshaler); ok { + b, err := k.MarshalText() + if err != nil { + return "", fmt.Errorf("failed to marshal text, %w", err) + } + return string(b), err + } + + switch keyVal.Kind() { + case reflect.Bool, + reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + + return fmt.Sprint(keyVal.Interface()), nil + + default: + return "", &InvalidMarshalError{ + msg: "map key type not supported, must be string, number, bool, or TextMarshaler", + } + } +} + func (e *Encoder) encodeSlice(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { if v.Type().Elem().Kind() == reflect.Uint8 { slice := reflect.MakeSlice(byteSliceType, v.Len(), v.Len()) diff --git a/feature/dynamodbstreams/attributevalue/encode_test.go b/feature/dynamodbstreams/attributevalue/encode_test.go index 0079eea8f76..af64000ddaa 100644 --- a/feature/dynamodbstreams/attributevalue/encode_test.go +++ b/feature/dynamodbstreams/attributevalue/encode_test.go @@ -1,13 +1,14 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" "reflect" "strconv" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" "github.com/google/go-cmp/cmp" @@ -366,3 +367,130 @@ func TestEncoderFieldByIndex(t *testing.T) { t.Error("expected f to be of kind Int with value equal to outer.Inner") } } + +func TestMarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input interface{} + expectAV map[string]types.AttributeValue + }{ + "string key": { + input: map[string]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "string alias key": { + input: map[StrAlias]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "Number key": { + input: map[Number]interface{}{ + Number("1"): 123, + Number("2"): "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int key": { + input: map[int]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int alias key": { + input: map[IntAlias]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool key": { + input: map[bool]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool alias key": { + input: map[BoolAlias]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler key": { + input: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler ptr key": { + input: map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + av, err := MarshalMap(c.input) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + cmpOptions := cmp.Options{ + cmpopts.IgnoreUnexported(types.AttributeValueMemberM{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberN{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBOOL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberB{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberSS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNULL{}), + } + if diff := cmp.Diff(c.expectAV, av, cmpOptions...); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} diff --git a/feature/dynamodbstreams/attributevalue/field.go b/feature/dynamodbstreams/attributevalue/field.go index 7abd3479a96..4f63bc7df99 100644 --- a/feature/dynamodbstreams/attributevalue/field.go +++ b/feature/dynamodbstreams/attributevalue/field.go @@ -5,6 +5,8 @@ import ( "sort" ) +const defaultTagKey = "dynamodbav" + type field struct { tag @@ -46,7 +48,12 @@ type structFieldOptions struct { // unionStructFields returns a list of fields for the given type. Type info is cached // to avoid repeated calls into the reflect package func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { - if cached, ok := fieldCache.Load(t); ok { + key := fieldCacheKey{ + typ: t, + opts: opts, + } + + if cached, ok := fieldCache.Load(key); ok { return cached } @@ -62,7 +69,7 @@ func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { fs.fieldsByName[f.Name] = i } - cached, _ := fieldCache.LoadOrStore(t, fs) + cached, _ := fieldCache.LoadOrStore(key, fs) return cached } @@ -105,7 +112,7 @@ func enumFields(t reflect.Type, opts structFieldOptions) []field { fieldTag := tag{} fieldTag.parseAVTag(sf.Tag) // Because MarshalOptions.TagKey must be explicitly set. - if opts.TagKey != "" && fieldTag == (tag{}) { + if opts.TagKey != "" && opts.TagKey != defaultTagKey { fieldTag.parseStructTag(opts.TagKey, sf.Tag) } diff --git a/feature/dynamodbstreams/attributevalue/field_cache.go b/feature/dynamodbstreams/attributevalue/field_cache.go index 60a9d9c7499..c0fc4679a86 100644 --- a/feature/dynamodbstreams/attributevalue/field_cache.go +++ b/feature/dynamodbstreams/attributevalue/field_cache.go @@ -1,25 +1,31 @@ package attributevalue import ( + "reflect" "strings" "sync" ) -var fieldCache fieldCacher +var fieldCache = &fieldCacher{} + +type fieldCacheKey struct { + typ reflect.Type + opts structFieldOptions +} type fieldCacher struct { cache sync.Map } -func (c *fieldCacher) Load(t interface{}) (*cachedFields, bool) { - if v, ok := c.cache.Load(t); ok { +func (c *fieldCacher) Load(key fieldCacheKey) (*cachedFields, bool) { + if v, ok := c.cache.Load(key); ok { return v.(*cachedFields), true } return nil, false } -func (c *fieldCacher) LoadOrStore(t interface{}, fs *cachedFields) (*cachedFields, bool) { - v, ok := c.cache.LoadOrStore(t, fs) +func (c *fieldCacher) LoadOrStore(key fieldCacheKey, fs *cachedFields) (*cachedFields, bool) { + v, ok := c.cache.LoadOrStore(key, fs) return v.(*cachedFields), ok } diff --git a/feature/dynamodbstreams/attributevalue/field_test.go b/feature/dynamodbstreams/attributevalue/field_test.go index 82a09d6cf99..9c2b291703d 100644 --- a/feature/dynamodbstreams/attributevalue/field_test.go +++ b/feature/dynamodbstreams/attributevalue/field_test.go @@ -1,6 +1,7 @@ package attributevalue import ( + "fmt" "reflect" "testing" ) @@ -22,7 +23,7 @@ type unionComplex struct { } type unionTagged struct { - A int `json:"A"` + A int `dynamodbav:"ddbav" json:"A" taga:"TagA" tagb:"TagB"` } type unionTaggedComplex struct { @@ -32,97 +33,211 @@ type unionTaggedComplex struct { } func TestUnionStructFields(t *testing.T) { - var cases = []struct { + origFieldCache := fieldCache + defer func() { fieldCache = origFieldCache }() + + fieldCache = &fieldCacher{} + + var cases = map[string]struct { in interface{} + opts structFieldOptions expect []testUnionValues }{ - { - in: unionSimple{1, "2", []string{"abc"}}, + "simple input": { + in: unionSimple{1, "2", []string{"abc"}}, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"A", 1}, {"B", "2"}, {"C", []string{"abc"}}, }, }, - { + "nested struct": { in: unionComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, A: 2, }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"B", "2"}, {"C", []string{"abc"}}, {"A", 2}, }, }, - { + "with TagKey unset": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"ddbav", 3}, + {"B", "3"}, + }, + }, + "with TagKey json": { in: unionTaggedComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, unionTagged: unionTagged{3}, B: "3", }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"C", []string{"abc"}}, {"A", 3}, {"B", "3"}, }, }, + "with TagKey taga": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "taga"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagA", 3}, + {"B", "3"}, + }, + }, + "with TagKey tagb": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "tagb"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagB", 3}, + {"B", "3"}, + }, + }, } - for i, c := range cases { - v := reflect.ValueOf(c.in) + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v := reflect.ValueOf(c.in) - fields := unionStructFields(v.Type(), structFieldOptions{TagKey: "json"}) - for j, f := range fields.All() { - expected := c.expect[j] - if e, a := expected.Name, f.Name; e != a { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) - } - actual := v.FieldByIndex(f.Index).Interface() - if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) + fields := unionStructFields(v.Type(), c.opts) + for i, f := range fields.All() { + expected := c.expect[i] + if e, a := expected.Name, f.Name; e != a { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } + actual := v.FieldByIndex(f.Index).Interface() + if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } } - } + }) } } func TestCachedFields(t *testing.T) { type myStruct struct { - Dog int + Dog int `tag1:"rabbit" tag2:"cow" tag3:"horse"` CAT string bird bool } - fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{}) - - const expectedNumFields = 2 - if numFields := len(fields.All()); numFields != expectedNumFields { - t.Errorf("expected number of fields to be %d but got %d", expectedNumFields, numFields) - } - - cases := []struct { + cases := map[string][]struct { Name string FieldName string Found bool }{ - {"Dog", "Dog", true}, - {"dog", "Dog", true}, - {"DOG", "Dog", true}, - {"Yorkie", "", false}, - {"Cat", "CAT", true}, - {"cat", "CAT", true}, - {"CAT", "CAT", true}, - {"tiger", "", false}, - {"bird", "", false}, + "": { + {"Dog", "Dog", true}, + {"dog", "Dog", true}, + {"DOG", "Dog", true}, + {"Yorkie", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag1": { + {"rabbit", "rabbit", true}, + {"Rabbit", "rabbit", true}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag2": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "cow", true}, + {"Cow", "cow", true}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag3": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "horse", true}, + {"Horse", "horse", true}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, } - for _, c := range cases { - f, found := fields.FieldByName(c.Name) - if found != c.Found { - t.Errorf("expected found to be %v but got %v", c.Found, found) - } - if found && f.Name != c.FieldName { - t.Errorf("expected field name to be %s but got %s", c.FieldName, f.Name) + for tagKey, cs := range cases { + for _, c := range cs { + name := tagKey + if name == "" { + name = "none" + } + t.Run(fmt.Sprintf("%s/%s", name, c.Name), func(t *testing.T) { + t.Parallel() + + fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{ + TagKey: tagKey, + }) + + const expectedNumFields = 2 + if numFields := len(fields.All()); numFields != expectedNumFields { + t.Errorf("expect %v fields, got %d", expectedNumFields, numFields) + } + + f, found := fields.FieldByName(c.Name) + if found != c.Found { + t.Errorf("expect %v found, got %v", c.Found, found) + } + if found && f.Name != c.FieldName { + t.Errorf("expect %v field name, got %s", c.FieldName, f.Name) + } + }) } } } diff --git a/feature/dynamodbstreams/attributevalue/marshaler_test.go b/feature/dynamodbstreams/attributevalue/marshaler_test.go index 26d4d91a5c7..ac0969111aa 100644 --- a/feature/dynamodbstreams/attributevalue/marshaler_test.go +++ b/feature/dynamodbstreams/attributevalue/marshaler_test.go @@ -520,7 +520,7 @@ func compareObjects(t *testing.T, expected interface{}, actual interface{}) { } func BenchmarkMarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -547,7 +547,7 @@ func BenchmarkMarshalOneMember(b *testing.B) { } func BenchmarkMarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -576,7 +576,7 @@ func BenchmarkMarshalTwoMembers(b *testing.B) { } func BenchmarkUnmarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", @@ -605,7 +605,7 @@ func BenchmarkUnmarshalOneMember(b *testing.B) { } func BenchmarkUnmarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", diff --git a/feature/dynamodbstreams/attributevalue/shared_test.go b/feature/dynamodbstreams/attributevalue/shared_test.go index 62071554878..41a747147a4 100644 --- a/feature/dynamodbstreams/attributevalue/shared_test.go +++ b/feature/dynamodbstreams/attributevalue/shared_test.go @@ -1,18 +1,37 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" + "fmt" "reflect" "strings" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" "github.com/google/go-cmp/cmp" ) +type testTextMarshaler struct { + Foo string +} + +func (t *testTextMarshaler) UnmarshalText(b []byte) error { + if !strings.HasPrefix(string(b), "Foo:") { + return fmt.Errorf(`missing "Foo:" prefix`) + } + + t.Foo = string(b)[len("Foo:"):] + return nil +} + +func (t testTextMarshaler) MarshalText() ([]byte, error) { + return []byte("Foo:" + t.Foo), nil +} + type testBinarySetStruct struct { Binarys [][]byte `dynamodbav:",binaryset"` } diff --git a/feature/dynamodbstreams/attributevalue/tag.go b/feature/dynamodbstreams/attributevalue/tag.go index 6eb901706fb..f01c432e6ea 100644 --- a/feature/dynamodbstreams/attributevalue/tag.go +++ b/feature/dynamodbstreams/attributevalue/tag.go @@ -18,7 +18,7 @@ type tag struct { } func (t *tag) parseAVTag(structTag reflect.StructTag) { - tagStr := structTag.Get("dynamodbav") + tagStr := structTag.Get(defaultTagKey) if len(tagStr) == 0 { return }