Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AttributeValue marshaling and names in expressions #1590

Merged
merged 3 commits into from Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/2096a6beb82d44bea4a469c197f6de40.json
@@ -0,0 +1,8 @@
{
"id": "2096a6be-b82d-44be-a4a4-69c197f6de40",
"type": "feature",
"description": "Add support for expression names with dots via new NameBuilder function NameNoDotSplit, related to [aws/aws-sdk-go#2570](https://github.com/aws/aws-sdk-go/issues/2570)",
"modules": [
"feature/dynamodb/expression"
]
}
9 changes: 9 additions & 0 deletions .changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json
@@ -0,0 +1,9 @@
{
"id": "98a8c469-e1d6-4c9a-a06b-0e467ac61dd9",
"type": "bugfix",
"description": "Fixes [#1569](https://github.com/aws/aws-sdk-go-v2/issues/1569) inconsistent serialization of Go struct field names",
"modules": [
"feature/dynamodb/attributevalue",
"feature/dynamodbstreams/attributevalue"
]
}
9 changes: 9 additions & 0 deletions .changelog/db81731fa3ab450e9ea3535a0d4aaedd.json
@@ -0,0 +1,9 @@
{
"id": "db81731f-a3ab-450e-9ea3-535a0d4aaedd",
"type": "feature",
"description": "Fixes [#645](https://github.com/aws/aws-sdk-go-v2/issues/645), [#411](https://github.com/aws/aws-sdk-go-v2/issues/411) by adding support for (un)marshaling AttributeValue maps to Go maps key types of string, number, bool, and types implementing encoding.Text(un)Marshaler interface",
"modules": [
"feature/dynamodb/attributevalue",
"feature/dynamodbstreams/attributevalue"
]
}
149 changes: 129 additions & 20 deletions feature/dynamodb/attributevalue/decode.go
@@ -1,6 +1,7 @@
package attributevalue

import (
"encoding"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -197,7 +198,7 @@ func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out int
}

// DecoderOptions is a collection of options to configure how the decoder
// unmarshalls the value.
// unmarshals the value.
type DecoderOptions struct {
// Support other custom struct tag keys, such as `yaml`, `json`, or `toml`.
// Note that values provided with a custom TagKey must also be supported
Expand All @@ -221,7 +222,7 @@ type Decoder struct {
// NewDecoder creates a new Decoder with default configuration. Use
// the `opts` functional options to override the default configuration.
func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder {
var options DecoderOptions
options := DecoderOptions{TagKey: defaultTagKey}
for _, fn := range optFns {
fn(&options)
}
Expand Down Expand Up @@ -254,14 +255,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag)
var u Unmarshaler
_, isNull := av.(*types.AttributeValueMemberNULL)
if av == nil || isNull {
u, v = indirect(v, true)
u, v = indirect(v, indirectOptions{decodeNull: true})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
return d.decodeNull(v)
}

u, v = indirect(v, false)
u, v = indirect(v, indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
Expand Down Expand Up @@ -386,7 +387,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), false)
u, elem := indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs})
}
Expand Down Expand Up @@ -513,7 +514,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), false)
u, elem := indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns})
}
Expand Down Expand Up @@ -564,32 +565,48 @@ func (d *Decoder) decodeList(avList []types.AttributeValue, v reflect.Value) err
return nil
}

func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) error {
func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) {
var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error

switch v.Kind() {
case reflect.Map:
t := v.Type()
if t.Key().Kind() != reflect.String {
return &UnmarshalTypeError{Value: "map string key", Type: t.Key()}
decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key())
if err != nil {
return err
}

if v.IsNil() {
v.Set(reflect.MakeMap(t))
v.Set(reflect.MakeMap(v.Type()))
}
case reflect.Struct:
case reflect.Interface:
v.Set(reflect.MakeMap(stringInterfaceMapType))
decodeMapKey = d.decodeString
v = v.Elem()
default:
return &UnmarshalTypeError{Value: "map", Type: v.Type()}
}

if v.Kind() == reflect.Map {
keyType := v.Type().Key()
valueType := v.Type().Elem()
for k, av := range avMap {
key := reflect.New(v.Type().Key()).Elem()
key.SetString(k)
elem := reflect.New(v.Type().Elem()).Elem()
key := reflect.New(keyType).Elem()
// handle pointer keys
_, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true})
if err := decodeMapKey(k, indirectKey, tag{}); err != nil {
return &UnmarshalTypeError{
Value: fmt.Sprintf("map key %q", k),
Type: keyType,
Err: err,
}
}

elem := reflect.New(valueType).Elem()
if err := d.decode(av, elem, tag{}); err != nil {
return err
}

v.SetMapIndex(key, elem)
}
} else if v.Kind() == reflect.Struct {
Expand All @@ -609,6 +626,50 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val
return nil
}

var numberType = reflect.TypeOf(Number(""))
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

func (d *Decoder) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, tag) error, error) {
// Test the key type to determine if it implements the TextUnmarshaler interface.
if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) {
return func(v string, k reflect.Value, _ tag) error {
if !k.CanAddr() {
return fmt.Errorf("cannot take address of map key, %v", k.Type())
}
return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v))
}, nil
}

var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error

switch keyType.Kind() {
case reflect.Bool:
decodeMapKey = func(v string, key reflect.Value, fieldTag tag) error {
b, err := strconv.ParseBool(v)
if err != nil {
return err
}
return d.decodeBool(b, key)
}
case reflect.String:
// Number type handled as a string
decodeMapKey = d.decodeString

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
decodeMapKey = d.decodeNumber

default:
return nil, &UnmarshalTypeError{
Value: "map key must be string, number, bool, or TextUnmarshaler",
Type: keyType,
}
}

return decodeMapKey, nil
}

func (d *Decoder) decodeNull(v reflect.Value) error {
if v.IsValid() && v.CanSet() {
v.Set(reflect.Zero(v.Type()))
Expand Down Expand Up @@ -675,7 +736,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), false)
u, elem := indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss})
}
Expand Down Expand Up @@ -713,38 +774,82 @@ func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value {
return v
}

type indirectOptions struct {
decodeNull bool
skipUnmarshaler bool
}

// indirect will walk a value's interface or pointer value types. Returning
// the final value or the value a unmarshaler is defined on.
//
// Based on the enoding/json type reflect value type indirection in Go Stdlib
// https://golang.org/src/encoding/json/decode.go indirect func.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) {
func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
// unexported embedded struct fields.
//
// The logic below effectively does this when it first addresses the value
// (to satisfy possible pointer methods) and continues to dereference
// subsequent pointers as necessary.
//
// After the first round-trip, we set v back to the original value to
// preserve the original RW flags contained in reflect.Value.
v0 := v
haveAddr := false

// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}

for {
// Load value from interface, but only if the result will be
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) {
haveAddr = false
v = e
continue
}
if e.Kind() != reflect.Ptr && e.IsValid() {
return nil, e
}
}
if v.Kind() != reflect.Ptr {
break
}
if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
if opts.decodeNull && v.CanSet() {
break
}

// Prevent infinite loop if v is an interface pointing to its own address:
// var v interface{}
// v = &v
if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v {
v = v.Elem()
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, reflect.Value{}
}
}
v = v.Elem()

if haveAddr {
v = v0 // restore original value after round-trip Value.Addr().Elem()
haveAddr = false
} else {
v = v.Elem()
}
}

return nil, v
Expand Down Expand Up @@ -782,8 +887,12 @@ func (n Number) String() string {
type UnmarshalTypeError struct {
Value string
Type reflect.Type
Err error
}

// Unwrap returns the underlying error if any.
func (e *UnmarshalTypeError) Unwrap() error { return e.Err }

// Error returns the string representation of the error.
// satisfying the error interface
func (e *UnmarshalTypeError) Error() string {
Expand Down