diff --git a/mapstructure.go b/mapstructure.go index f41bcc58..5f6c9991 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -862,6 +862,10 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re // Next get the actual value of this field and verify it is assignable // to the map value. v := dataVal.Field(i) + if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { + // Handle embedded struct pointers as embedded structs. + v = v.Elem() + } if !v.Type().AssignableTo(valMap.Type().Elem()) { return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem()) } @@ -1232,10 +1236,14 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e for i := 0; i < structType.NumField(); i++ { fieldType := structType.Field(i) - fieldKind := fieldType.Type.Kind() + fieldVal := structVal.Field(i) + if fieldVal.Kind() == reflect.Ptr && fieldVal.Elem().Kind() == reflect.Struct { + // Handle embedded struct pointers as embedded structs. + fieldVal = fieldVal.Elem() + } // If "squash" is specified in the tag, we squash the field down. - squash := d.config.Squash && fieldKind == reflect.Struct && fieldType.Anonymous + squash := d.config.Squash && fieldVal.Kind() == reflect.Struct && fieldType.Anonymous remain := false // We always parse the tags cause we're looking for other tags too @@ -1253,21 +1261,21 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } if squash { - if fieldKind != reflect.Struct { + if fieldVal.Kind() != reflect.Struct { errors = appendErrors(errors, - fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind)) + fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) } else { - structs = append(structs, structVal.FieldByName(fieldType.Name)) + structs = append(structs, fieldVal) } continue } // Build our field if remain { - remainField = &field{fieldType, structVal.Field(i)} + remainField = &field{fieldType, fieldVal} } else { // Normal struct field, store it away - fields = append(fields, field{fieldType, structVal.Field(i)}) + fields = append(fields, field{fieldType, fieldVal}) } } } diff --git a/mapstructure_test.go b/mapstructure_test.go index 63614e83..dafcb94a 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -61,6 +61,11 @@ type EmbeddedSquash struct { Vunique string } +type EmbeddedPointerSquash struct { + *Basic `mapstructure:",squash"` + Vunique string +} + type EmbeddedAndNamed struct { Basic Named Basic @@ -655,6 +660,56 @@ func TestDecodeFrom_EmbeddedSquash(t *testing.T) { } } +func TestDecode_EmbeddedPointerSquash_FromStructToMap(t *testing.T) { + t.Parallel() + + input := EmbeddedPointerSquash{ + Basic: &Basic{ + Vstring: "foo", + }, + Vunique: "bar", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + + if result["Vstring"] != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result["Vstring"]) + } + + if result["Vunique"] != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result["Vunique"]) + } +} + +func TestDecode_EmbeddedPointerSquash_FromMapToStruct(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "Vstring": "foo", + "Vunique": "bar", + } + + result := EmbeddedPointerSquash{ + Basic: &Basic{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + + if result.Vstring != "foo" { + t.Errorf("vstring value should be 'foo': %#v", result.Vstring) + } + + if result.Vunique != "bar" { + t.Errorf("vunique value should be 'bar': %#v", result.Vunique) + } +} + func TestDecode_EmbeddedSquashConfig(t *testing.T) { t.Parallel() @@ -2334,7 +2389,26 @@ func TestDecode_StructTaggedWithOmitempty_KeepNonEmptyValues(t *testing.T) { "visible-map": emptyMap, "omittable-map": map[string]interface{}{"k": "v"}, "visible-nested": emptyNested, - "omittable-nested": &Nested{}, + "omittable-nested": map[string]interface{}{ + "Vbar": map[string]interface{}{ + "Vbool": false, + "Vdata": interface{}(nil), + "Vextra": "", + "Vfloat": float64(0), + "Vint": 0, + "Vint16": int16(0), + "Vint32": int32(0), + "Vint64": int64(0), + "Vint8": int8(0), + "VjsonFloat": float64(0), + "VjsonInt": 0, + "VjsonNumber": json.Number(""), + "VjsonUint": uint(0), + "Vstring": "", + "Vuint": uint(0), + }, + "Vfoo": "", + }, } actual := &map[string]interface{}{}