From bfe0d3c2f4c47d70d56f279e6c8867b05a3117cb Mon Sep 17 00:00:00 2001 From: Matt Keeler Date: Wed, 22 Jul 2020 11:52:35 -0400 Subject: [PATCH] Ensure that intermediate maps during struct to struct decoding are settable --- mapstructure.go | 31 ++++++++++++++++--- mapstructure_bugs_test.go | 64 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/mapstructure.go b/mapstructure.go index b384d9d9..f41bcc58 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -906,11 +906,22 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re mType := reflect.MapOf(vKeyType, vElemType) vMap := reflect.MakeMap(mType) - err := d.decode(keyName, x.Interface(), vMap) + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(vMap.Type()) + reflect.Indirect(addrVal).Set(vMap) + + err := d.decode(keyName, x.Interface(), reflect.Indirect(addrVal)) if err != nil { return err } + // the underlying map may have been completely overwritten so pull + // it indirectly out of the enclosing value. + vMap = reflect.Indirect(addrVal) + if squash { for _, k := range vMap.MapKeys() { valMap.SetMapIndex(k, vMap.MapIndex(k)) @@ -1154,13 +1165,23 @@ func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) // Not the most efficient way to do this but we can optimize later if // we want to. To convert from struct to struct we go to map first // as an intermediary. - m := make(map[string]interface{}) - mval := reflect.Indirect(reflect.ValueOf(&m)) - if err := d.decodeMapFromStruct(name, dataVal, mval, mval); err != nil { + + // Make a new map to hold our result + mapType := reflect.TypeOf((map[string]interface{})(nil)) + mval := reflect.MakeMap(mapType) + + // Creating a pointer to a map so that other methods can completely + // overwrite the map if need be (looking at you decodeMapFromMap). The + // indirection allows the underlying map to be settable (CanSet() == true) + // where as reflect.MakeMap returns an unsettable map. + addrVal := reflect.New(mval.Type()) + + reflect.Indirect(addrVal).Set(mval) + if err := d.decodeMapFromStruct(name, dataVal, reflect.Indirect(addrVal), mval); err != nil { return err } - result := d.decodeStructFromMap(name, mval, val) + result := d.decodeStructFromMap(name, reflect.Indirect(addrVal), val) return result default: diff --git a/mapstructure_bugs_test.go b/mapstructure_bugs_test.go index f030c2d6..8c87ebd4 100644 --- a/mapstructure_bugs_test.go +++ b/mapstructure_bugs_test.go @@ -3,6 +3,7 @@ package mapstructure import ( "reflect" "testing" + "time" ) // GH-1, GH-10, GH-96 @@ -475,3 +476,66 @@ func TestDecodeBadDataTypeInSlice(t *testing.T) { t.Error("An error was expected, got nil") } } + +// #202 Ensure that intermediate maps in the struct -> struct decode process are settable +// and not just the elements within them. +func TestDecodeIntermeidateMapsSettable(t *testing.T) { + type Timestamp struct { + Seconds int64 + Nanos int32 + } + + type TsWrapper struct { + Timestamp *Timestamp + } + + type TimeWrapper struct { + Timestamp time.Time + } + + input := TimeWrapper{ + Timestamp: time.Unix(123456789, 987654), + } + + expected := TsWrapper{ + Timestamp: &Timestamp{ + Seconds: 123456789, + Nanos: 987654, + }, + } + + timePtrType := reflect.TypeOf((*time.Time)(nil)) + mapStrInfType := reflect.TypeOf((map[string]interface{})(nil)) + + var actual TsWrapper + decoder, err := NewDecoder(&DecoderConfig{ + Result: &actual, + DecodeHook: func(from, to reflect.Type, data interface{}) (interface{}, error) { + if from == timePtrType && to == mapStrInfType { + ts := data.(*time.Time) + nanos := ts.UnixNano() + + seconds := nanos / 1000000000 + nanos = nanos % 1000000000 + + return &map[string]interface{}{ + "Seconds": seconds, + "Nanos": int32(nanos), + }, nil + } + return data, nil + }, + }) + + if err != nil { + t.Fatalf("failed to create decoder: %v", err) + } + + if err := decoder.Decode(&input); err != nil { + t.Fatalf("failed to decode input: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("expected: %#[1]v (%[1]T), got: %#[2]v (%[2]T)", expected, actual) + } +}