Skip to content

Commit

Permalink
allow decoding nested struct ptrs to map
Browse files Browse the repository at this point in the history
  • Loading branch information
vlanse committed Jan 21, 2022
1 parent b9b99d7 commit 94c41ea
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
27 changes: 27 additions & 0 deletions mapstructure.go
Expand Up @@ -909,6 +909,8 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re
// If Squash is set in the config, we squash the field down.
squash := d.config.Squash && v.Kind() == reflect.Struct && f.Anonymous

v = dereferencePtrToStructIfNeeded(v, d.config.TagName)

// Determine the name of the key in the map
if index := strings.Index(tagValue, ","); index != -1 {
if tagValue[:index] == "-" {
Expand Down Expand Up @@ -1465,3 +1467,28 @@ func getKind(val reflect.Value) reflect.Kind {
return kind
}
}

func isStructTypeConvertibleToMap(typ reflect.Type, checkMapstructureTags bool, tagName string) bool {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath == "" && !checkMapstructureTags { // check for unexported fields
return true
}
if checkMapstructureTags && f.Tag.Get(tagName) != "" { // check for mapstructure tags inside
return true
}
}
return false
}

func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Value {
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
return v
}
deref := v.Elem()
derefT := deref.Type()
if isStructTypeConvertibleToMap(derefT, true, tagName) {
return deref
}
return v
}
83 changes: 83 additions & 0 deletions mapstructure_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"sort"
"strings"
"testing"
"time"
)

type Basic struct {
Expand Down Expand Up @@ -67,6 +68,20 @@ type EmbeddedPointerSquash struct {
Vunique string
}

type BasicMapStructure struct {
Vunique string `mapstructure:"vunique"`
Vtime *time.Time `mapstructure:"time"`
}

type NestedPointerWithMapstructure struct {
Vbar *BasicMapStructure `mapstructure:"vbar"`
}

type EmbeddedPointerSquashWithNestedMapstructure struct {
*NestedPointerWithMapstructure `mapstructure:",squash"`
Vunique string
}

type EmbeddedAndNamed struct {
Basic
Named Basic
Expand Down Expand Up @@ -716,6 +731,74 @@ func TestDecode_EmbeddedPointerSquash_FromMapToStruct(t *testing.T) {
}
}

func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromStructToMap(t *testing.T) {
t.Parallel()

vTime := time.Now()

input := EmbeddedPointerSquashWithNestedMapstructure{
NestedPointerWithMapstructure: &NestedPointerWithMapstructure{
Vbar: &BasicMapStructure{
Vunique: "bar",
Vtime: &vTime,
},
},
Vunique: "foo",
}

var result map[string]interface{}
err := Decode(input, &result)
if err != nil {
t.Fatalf("got an err: %s", err.Error())
}
expected := map[string]interface{}{
"vbar": map[string]interface{}{
"vunique": "bar",
"time": &vTime,
},
"Vunique": "foo",
}

if !reflect.DeepEqual(result, expected) {
t.Errorf("result should be %#v: got %#v", expected, result)
}
}

func TestDecode_EmbeddedPointerSquashWithNestedMapstructure_FromMapToStruct(t *testing.T) {
t.Parallel()

vTime := time.Now()

input := map[string]interface{}{
"vbar": map[string]interface{}{
"vunique": "bar",
"time": &vTime,
},
"Vunique": "foo",
}

result := EmbeddedPointerSquashWithNestedMapstructure{
NestedPointerWithMapstructure: &NestedPointerWithMapstructure{},
}
err := Decode(input, &result)
if err != nil {
t.Fatalf("got an err: %s", err.Error())
}
expected := EmbeddedPointerSquashWithNestedMapstructure{
NestedPointerWithMapstructure: &NestedPointerWithMapstructure{
Vbar: &BasicMapStructure{
Vunique: "bar",
Vtime: &vTime,
},
},
Vunique: "foo",
}

if !reflect.DeepEqual(result, expected) {
t.Errorf("result should be %#v: got %#v", expected, result)
}
}

func TestDecode_EmbeddedSquashConfig(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 94c41ea

Please sign in to comment.