diff --git a/mapstructure.go b/mapstructure.go index 6b81b00..d7f47ad 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -909,6 +909,8 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re // If Squash is set in the config, we squash the field down. squash := d.config.Squash && v.Kind() == reflect.Struct && f.Anonymous + // v = dereferencePtrToStructIfNeeded(v, d.config.TagName) + // Determine the name of the key in the map if index := strings.Index(tagValue, ","); index != -1 { if tagValue[:index] == "-" { @@ -1465,3 +1467,28 @@ func getKind(val reflect.Value) reflect.Kind { return kind } } + +func isStructTypeConvertibleToMap(typ reflect.Type, checkMapstructureTags bool, tagName string) bool { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + if f.PkgPath == "" && !checkMapstructureTags { // check for unexported fields + return true + } + if checkMapstructureTags && f.Tag.Get(tagName) != "" { // check for mapstructure tags inside + return true + } + } + return false +} + +func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Value { + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return v + } + deref := v.Elem() + derefT := deref.Type() + if isStructTypeConvertibleToMap(derefT, true, tagName) { + return deref + } + return v +} diff --git a/mapstructure_test.go b/mapstructure_test.go index a78e77a..7853248 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -7,6 +7,7 @@ import ( "sort" "strings" "testing" + "time" ) type Basic struct { @@ -67,6 +68,20 @@ type EmbeddedPointerSquash struct { Vunique string } +type BasicMapStructure struct { + Vunique string `mapstructure:"vunique"` + Vtime *time.Time `mapstructure:"time"` +} + +type NestedPointerWithMapstructure struct { + Vbar *BasicMapStructure `mapstructure:"vbar"` +} + +type EmbeddedPointerSquashWithNestedMapstructure struct { + *NestedPointerWithMapstructure `mapstructure:",squash"` + Vunique string +} + type EmbeddedAndNamed struct { Basic Named Basic @@ -716,6 +731,74 @@ func TestDecode_EmbeddedPointerSquash_FromMapToStruct(t *testing.T) { } } +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromStructToMap(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + var result map[string]interface{} + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + expected := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + +func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromMapToStruct(t *testing.T) { + t.Parallel() + + vTime := time.Now() + + input := map[string]interface{}{ + "vbar": map[string]interface{}{ + "vunique": "bar", + "time": &vTime, + }, + "Vunique": "foo", + } + + result := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err.Error()) + } + expected := EmbeddedPointerSquashWithNestedMapstructure{ + NestedPointerWithMapstructure: &NestedPointerWithMapstructure{ + Vbar: &BasicMapStructure{ + Vunique: "bar", + Vtime: &vTime, + }, + }, + Vunique: "foo", + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("result should be %#v: got %#v", expected, result) + } +} + func TestDecode_EmbeddedSquashConfig(t *testing.T) { t.Parallel()