diff --git a/mapstructure.go b/mapstructure.go index c8619fa..99bee4f 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -1056,7 +1056,7 @@ func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) (b // pointer to be nil as well. isNil := data == nil if !isNil { - switch v := reflect.Indirect(reflect.ValueOf(data)); v.Kind() { + switch v := reflect.ValueOf(data); v.Kind() { case reflect.Chan, reflect.Func, reflect.Interface, diff --git a/mapstructure_test.go b/mapstructure_test.go index d604fcf..60cba9c 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -2,9 +2,11 @@ package mapstructure import ( "encoding/json" + "fmt" "io" "reflect" "sort" + "strconv" "strings" "testing" "time" @@ -2821,6 +2823,85 @@ func TestDecoder_IgnoreUntaggedFieldsWithStruct(t *testing.T) { } } +func TestTypedNilPostHooks(t *testing.T) { + type customType1 int + type customType2 float64 + type configType struct { + C *customType1 `mapstructure:"c"` + A *customType2 `mapstructure:"a"` + } + + for i, marshalledConfig := range []map[string]interface{}{ + { + "c": (*customType1)(nil), + "a": (*customType2)(floatPtr(42.42)), + }, + { + "c": "", + "a": "42.42", + }, + } { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + actual := &configType{} + decoder, err := NewDecoder(&DecoderConfig{ + Result: actual, + DecodeHook: ComposeDecodeHookFunc( + func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + var enum customType1 + if f.Kind() != reflect.String || t != reflect.TypeOf(&enum) { + return data, nil + } + s := data.(string) + if s == "" { + // Returning an untyped nil here would cause a panic, as `from.Type()` + // is invalid for nil. + return (*customType1)(nil), nil + } + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, err + } + enum = customType1(n) + return &enum, nil + }, + func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + var enum customType2 + if f.Kind() != reflect.String || t != reflect.TypeOf(&enum) { + return data, nil + } + s := data.(string) + if s == "" { + return (*customType2)(nil), nil + } + n, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, err + } + enum = customType2(n) + return &enum, nil + }, + ), + }) + if err != nil { + t.Fatal(err) + } + + if err := decoder.Decode(marshalledConfig); err != nil { + t.Fatal(err) + } + + expected := &configType{ + C: nil, + A: (*customType2)(floatPtr(42.42)), + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Decode() expected: %#v, got: %#v", expected, actual) + } + }) + } +} + func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { var result Slice err := Decode(input, &result)