diff --git a/decode_hooks.go b/decode_hooks.go index 1f0abc65..77578d00 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -16,10 +16,11 @@ func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { // Create variables here so we can reference them with the reflect pkg var f1 DecodeHookFuncType var f2 DecodeHookFuncKind + var f3 DecodeHookFuncValue // Fill in the variables into this interface and the rest is done // automatically using the reflect package. - potential := []interface{}{f1, f2} + potential := []interface{}{f1, f2, f3} v := reflect.ValueOf(h) vt := v.Type() @@ -38,13 +39,15 @@ func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc { // that took reflect.Kind instead of reflect.Type. func DecodeHookExec( raw DecodeHookFunc, - from reflect.Type, to reflect.Type, - data interface{}) (interface{}, error) { + from reflect.Value, to reflect.Value) (interface{}, error) { + switch f := typedDecodeHook(raw).(type) { case DecodeHookFuncType: - return f(from, to, data) + return f(from.Type(), to.Type(), from.Interface()) case DecodeHookFuncKind: - return f(from.Kind(), to.Kind(), data) + return f(from.Kind(), to.Kind(), from.Interface()) + case DecodeHookFuncValue: + return f(from, to) default: return nil, errors.New("invalid decode hook signature") } @@ -56,22 +59,16 @@ func DecodeHookExec( // The composed funcs are called in order, with the result of the // previous transformation. func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data interface{}) (interface{}, error) { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { var err error + var data interface{} + newFrom := f for _, f1 := range fs { - data, err = DecodeHookExec(f1, f, t, data) + data, err = DecodeHookExec(f1, newFrom, t) if err != nil { return nil, err } - - // Modify the from kind to be correct with the new data - f = nil - if val := reflect.ValueOf(data); val.IsValid() { - f = val.Type() - } + newFrom = reflect.ValueOf(data) } return data, nil @@ -215,3 +212,21 @@ func WeaklyTypedHook( return data, nil } + +func RecursiveStructToMapHookFunc() DecodeHookFunc { + return func(f reflect.Value, t reflect.Value) (interface{}, error) { + if f.Kind() != reflect.Struct { + return f.Interface(), nil + } + + var i interface{} = struct{}{} + if t.Type() != reflect.TypeOf(&i).Elem() { + return f.Interface(), nil + } + + m := make(map[string]interface{}) + t.Set(reflect.ValueOf(m)) + + return f.Interface(), nil + } +} diff --git a/decode_hooks_test.go b/decode_hooks_test.go index 028afbdb..d6902534 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -26,7 +26,7 @@ func TestComposeDecodeHookFunc(t *testing.T) { f := ComposeDecodeHookFunc(f1, f2) result, err := DecodeHookExec( - f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "") + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) if err != nil { t.Fatalf("bad: %s", err) } @@ -47,7 +47,7 @@ func TestComposeDecodeHookFunc_err(t *testing.T) { f := ComposeDecodeHookFunc(f1, f2) _, err := DecodeHookExec( - f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), 42) + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) if err.Error() != "foo" { t.Fatalf("bad: %s", err) } @@ -74,7 +74,7 @@ func TestComposeDecodeHookFunc_kinds(t *testing.T) { f := ComposeDecodeHookFunc(f1, f2) _, err := DecodeHookExec( - f, reflect.TypeOf(""), reflect.TypeOf([]byte("")), "") + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) if err != nil { t.Fatalf("bad: %s", err) } @@ -86,34 +86,31 @@ func TestComposeDecodeHookFunc_kinds(t *testing.T) { func TestStringToSliceHookFunc(t *testing.T) { f := StringToSliceHookFunc(",") - strType := reflect.TypeOf("") - sliceType := reflect.TypeOf([]byte("")) + strValue := reflect.ValueOf("42") + sliceValue := reflect.ValueOf([]byte("42")) cases := []struct { - f, t reflect.Type - data interface{} + f, t reflect.Value result interface{} err bool }{ - {sliceType, sliceType, 42, 42, false}, - {strType, strType, 42, 42, false}, + {sliceValue, sliceValue, []byte("42"), false}, + {strValue, strValue, "42", false}, { - strType, - sliceType, - "foo,bar,baz", + reflect.ValueOf("foo,bar,baz"), + sliceValue, []string{"foo", "bar", "baz"}, false, }, { - strType, - sliceType, - "", + reflect.ValueOf(""), + sliceValue, []string{}, false, }, } for i, tc := range cases { - actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -128,21 +125,20 @@ func TestStringToSliceHookFunc(t *testing.T) { func TestStringToTimeDurationHookFunc(t *testing.T) { f := StringToTimeDurationHookFunc() - strType := reflect.TypeOf("") - timeType := reflect.TypeOf(time.Duration(5)) + timeValue := reflect.ValueOf(time.Duration(5)) + strValue := reflect.ValueOf("") cases := []struct { - f, t reflect.Type - data interface{} + f, t reflect.Value result interface{} err bool }{ - {strType, timeType, "5s", 5 * time.Second, false}, - {strType, timeType, "5", time.Duration(0), true}, - {strType, strType, "5", "5", false}, + {reflect.ValueOf("5s"), timeValue, 5 * time.Second, false}, + {reflect.ValueOf("5"), timeValue, time.Duration(0), true}, + {reflect.ValueOf("5"), strValue, "5", false}, } for i, tc := range cases { - actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -155,24 +151,23 @@ func TestStringToTimeDurationHookFunc(t *testing.T) { } func TestStringToTimeHookFunc(t *testing.T) { - strType := reflect.TypeOf("") - timeType := reflect.TypeOf(time.Time{}) + strValue := reflect.ValueOf("5") + timeValue := reflect.ValueOf(time.Time{}) cases := []struct { - f, t reflect.Type + f, t reflect.Value layout string - data interface{} result interface{} err bool }{ - {strType, timeType, time.RFC3339, "2006-01-02T15:04:05Z", + {reflect.ValueOf("2006-01-02T15:04:05Z"), timeValue, time.RFC3339, time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC), false}, - {strType, timeType, time.RFC3339, "5", time.Time{}, true}, - {strType, strType, time.RFC3339, "5", "5", false}, + {strValue, timeValue, time.RFC3339, time.Time{}, true}, + {strValue, strValue, time.RFC3339, "5", false}, } for i, tc := range cases { f := StringToTimeHookFunc(tc.layout) - actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -185,23 +180,22 @@ func TestStringToTimeHookFunc(t *testing.T) { } func TestStringToIPHookFunc(t *testing.T) { - strType := reflect.TypeOf("") - ipType := reflect.TypeOf(net.IP{}) + strValue := reflect.ValueOf("5") + ipValue := reflect.ValueOf(net.IP{}) cases := []struct { - f, t reflect.Type - data interface{} + f, t reflect.Value result interface{} err bool }{ - {strType, ipType, "1.2.3.4", + {reflect.ValueOf("1.2.3.4"), ipValue, net.IPv4(0x01, 0x02, 0x03, 0x04), false}, - {strType, ipType, "5", net.IP{}, true}, - {strType, strType, "5", "5", false}, + {strValue, ipValue, net.IP{}, true}, + {strValue, strValue, "5", false}, } for i, tc := range cases { f := StringToIPHookFunc() - actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -214,28 +208,27 @@ func TestStringToIPHookFunc(t *testing.T) { } func TestStringToIPNetHookFunc(t *testing.T) { - strType := reflect.TypeOf("") - ipNetType := reflect.TypeOf(net.IPNet{}) + strValue := reflect.ValueOf("5") + ipNetValue := reflect.ValueOf(net.IPNet{}) var nilNet *net.IPNet = nil cases := []struct { - f, t reflect.Type - data interface{} + f, t reflect.Value result interface{} err bool }{ - {strType, ipNetType, "1.2.3.4/24", + {reflect.ValueOf("1.2.3.4/24"), ipNetValue, &net.IPNet{ IP: net.IP{0x01, 0x02, 0x03, 0x00}, Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00), }, false}, - {strType, ipNetType, "5", nilNet, true}, - {strType, strType, "5", "5", false}, + {strValue, ipNetValue, nilNet, true}, + {strValue, strValue, "5", false}, } for i, tc := range cases { f := StringToIPNetHookFunc() - actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -250,67 +243,58 @@ func TestStringToIPNetHookFunc(t *testing.T) { func TestWeaklyTypedHook(t *testing.T) { var f DecodeHookFunc = WeaklyTypedHook - boolType := reflect.TypeOf(true) - strType := reflect.TypeOf("") - sliceType := reflect.TypeOf([]byte("")) + strValue := reflect.ValueOf("") cases := []struct { - f, t reflect.Type - data interface{} + f, t reflect.Value result interface{} err bool }{ // TO STRING { - boolType, - strType, - false, + reflect.ValueOf(false), + strValue, "0", false, }, { - boolType, - strType, - true, + reflect.ValueOf(true), + strValue, "1", false, }, { - reflect.TypeOf(float32(1)), - strType, - float32(7), + reflect.ValueOf(float32(7)), + strValue, "7", false, }, { - reflect.TypeOf(int(1)), - strType, - int(7), + reflect.ValueOf(int(7)), + strValue, "7", false, }, { - sliceType, - strType, - []uint8("foo"), + reflect.ValueOf([]uint8("foo")), + strValue, "foo", false, }, { - reflect.TypeOf(uint(1)), - strType, - uint(7), + reflect.ValueOf(uint(7)), + strValue, "7", false, }, } for i, tc := range cases { - actual, err := DecodeHookExec(f, tc.f, tc.t, tc.data) + actual, err := DecodeHookExec(f, tc.f, tc.t) if tc.err != (err != nil) { t.Fatalf("case %d: expected err %#v", i, tc.err) } @@ -321,3 +305,117 @@ func TestWeaklyTypedHook(t *testing.T) { } } } + +func TestStructToMapHookFuncTabled(t *testing.T) { + var f DecodeHookFunc = RecursiveStructToMapHookFunc() + + type b struct { + TestKey string + } + + type a struct { + Sub b + } + + testStruct := a{ + Sub: b{ + TestKey: "testval", + }, + } + + testMap := map[string]interface{}{ + "Sub": map[string]interface{}{ + "TestKey": "testval", + }, + } + + cases := []struct { + name string + receiver interface{} + input interface{} + expected interface{} + err bool + }{ + { + "map receiver", + func() interface{} { + var res map[string]interface{} + return &res + }(), + testStruct, + &testMap, + false, + }, + { + "interface receiver", + func() interface{} { + var res interface{} + return &res + }(), + testStruct, + func() interface{} { + var exp interface{} = testMap + return &exp + }(), + false, + }, + { + "slice receiver errors", + func() interface{} { + var res []string + return &res + }(), + testStruct, + new([]string), + true, + }, + { + "slice to slice - no change", + func() interface{} { + var res []string + return &res + }(), + []string{"a", "b"}, + &[]string{"a", "b"}, + false, + }, + { + "string to string - no change", + func() interface{} { + var res string + return &res + }(), + "test", + func() *string { + s := "test" + return &s + }(), + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &DecoderConfig{ + DecodeHook: f, + Result: tc.receiver, + } + + d, err := NewDecoder(cfg) + if err != nil { + t.Fatalf("unexpected err %#v", err) + } + + err = d.Decode(tc.input) + if tc.err != (err != nil) { + t.Fatalf("expected err %#v", err) + } + + if !reflect.DeepEqual(tc.expected, tc.receiver) { + t.Fatalf("expected %#v, got %#v", + tc.expected, tc.receiver) + } + }) + + } +} diff --git a/mapstructure.go b/mapstructure.go index daea3318..9e3ee844 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -139,6 +139,10 @@ type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface // source and target types. type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) +// DecodeHookFuncRaw is a DecodeHookFunc which has complete access to both the source and target +// values. +type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (interface{}, error) + // DecoderConfig is the configuration that is used to create a new decoder // and allows customization of various aspects of decoding. type DecoderConfig struct { @@ -368,9 +372,7 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e if d.config.DecodeHook != nil { // We have a DecodeHook, so let's pre-process the input. var err error - input, err = DecodeHookExec( - d.config.DecodeHook, - inputVal.Type(), outVal.Type(), input) + input, err = DecodeHookExec(d.config.DecodeHook, inputVal, outVal) if err != nil { return fmt.Errorf("error decoding '%s': %s", name, err) }