Skip to content

Commit

Permalink
chore: fix more cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kevwan committed Dec 10, 2022
1 parent 44d7995 commit 2fbed2c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 15 deletions.
27 changes: 14 additions & 13 deletions core/mapping/unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,17 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
default:
switch v := keythData.(type) {
case bool:
targetValue.SetMapIndex(key, reflect.ValueOf(v))
if dereffedElemKind == reflect.Bool {
targetValue.SetMapIndex(key, reflect.ValueOf(v))
} else {
return emptyValue, errTypeMismatch
}
case string:
targetValue.SetMapIndex(key, reflect.ValueOf(v))
if dereffedElemKind == reflect.String {
targetValue.SetMapIndex(key, reflect.ValueOf(v))
} else {
return emptyValue, errTypeMismatch
}
case json.Number:
target := reflect.New(dereffedElemType)
if err := setValue(dereffedElemKind, target.Elem(), v.String()); err != nil {
Expand All @@ -746,8 +754,10 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter

targetValue.SetMapIndex(key, target.Elem())
default:
if err := setMapIndex(targetValue, key, keythValue); err != nil {
return targetValue, err
if dereffedElemKind == keythValue.Kind() {
targetValue.SetMapIndex(key, keythValue)
} else {
return emptyValue, errTypeMismatch
}
}
}
Expand Down Expand Up @@ -941,15 +951,6 @@ func readKeys(key string) []string {
return keys
}

func setMapIndex(targetMap, key, value reflect.Value) error {
if targetMap.MapIndex(key).Kind() != value.Kind() {
return errTypeMismatch
}

targetMap.SetMapIndex(key, value)
return nil
}

func setSameKindValue(targetType reflect.Type, target reflect.Value, value interface{}) {
if reflect.ValueOf(value).Type().AssignableTo(targetType) {
target.Set(reflect.ValueOf(value))
Expand Down
59 changes: 58 additions & 1 deletion core/mapping/unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3563,14 +3563,71 @@ func TestGoogleUUID(t *testing.T) {
assert.Equal(t, "6ba7b810-9dad-11d1-80b4-00c04fd430c2", val.Uidp.String())
}

func TestUnmarshalJsonReaderWithTypeMismatch(t *testing.T) {
func TestUnmarshalJsonReaderWithTypeMismatchBool(t *testing.T) {
var req struct {
Params map[string]bool `json:"params"`
}
body := `{"params":{"a":"123"}}`
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(body), &req))
}

func TestUnmarshalJsonReaderWithTypeMismatchString(t *testing.T) {
var req struct {
Params map[string]string `json:"params"`
}
body := `{"params":{"a":{"a":123}}}`
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(body), &req))
}

func TestUnmarshalJsonReaderWithMismatchType(t *testing.T) {
type Req struct {
Params map[string]string `json:"params"`
}

var req Req
body := `{"params":{"a":{"a":123}}}`
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(body), &req))
}

func TestUnmarshalJsonReaderWithMismatchTypeBool(t *testing.T) {
type Req struct {
Params map[string]bool `json:"params"`
}

tests := []struct {
name string
input string
}{
{
name: "int",
input: `{"params":{"a":123}}`,
},
{
name: "int",
input: `{"params":{"a":"123"}}`,
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
var req Req
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(strings.NewReader(test.input), &req))
})
}
}

func TestUnmarshalJsonReaderWithMismatchTypeBoolMap(t *testing.T) {
var req struct {
Params map[string]string `json:"params"`
}
assert.Equal(t, errTypeMismatch, UnmarshalJsonMap(map[string]interface{}{
"params": map[string]interface{}{
"a": true,
},
}, &req))
}

func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ {
var a struct {
Expand Down
9 changes: 8 additions & 1 deletion core/mapping/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,14 @@ func ValidatePtr(v *reflect.Value) error {
func convertType(kind reflect.Kind, str string) (interface{}, error) {
switch kind {
case reflect.Bool:
return str == "1" || strings.ToLower(str) == "true", nil
switch strings.ToLower(str) {
case "1", "true":
return true, nil
case "0", "false":
return false, nil
default:
return false, errTypeMismatch
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(str, 10, 64)
if err != nil {
Expand Down

0 comments on commit 2fbed2c

Please sign in to comment.