From 2da917e49e599c2a14e2ad953a247783403251cf Mon Sep 17 00:00:00 2001 From: Arno Geurts Date: Mon, 7 Dec 2020 15:44:42 +0100 Subject: [PATCH] Ability to use encoding.TextUnmarshaler --- decode_hooks.go | 24 ++++++++++++++++++++++++ decode_hooks_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/decode_hooks.go b/decode_hooks.go index 77578d00..92e6f76f 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -1,6 +1,7 @@ package mapstructure import ( + "encoding" "errors" "fmt" "net" @@ -230,3 +231,26 @@ func RecursiveStructToMapHookFunc() DecodeHookFunc { return f.Interface(), nil } } + +// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies +// strings to the UnmarshalText function, when the target type +// implements the encoding.TextUnmarshaler interface +func TextUnmarshallerHookFunc() DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + result := reflect.New(t).Interface() + unmarshaller, ok := result.(encoding.TextUnmarshaler) + if !ok { + return data, nil + } + if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil { + return nil, err + } + return result, nil + } +} diff --git a/decode_hooks_test.go b/decode_hooks_test.go index d6902534..b3165bc9 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -2,6 +2,7 @@ package mapstructure import ( "errors" + "math/big" "net" "reflect" "testing" @@ -419,3 +420,28 @@ func TestStructToMapHookFuncTabled(t *testing.T) { } } + +func TestTextUnmarshallerHookFunc(t *testing.T) { + cases := []struct { + f, t reflect.Value + result interface{} + err bool + }{ + {reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false}, + {reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true}, + {reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false}, + } + + for i, tc := range cases { + f := TextUnmarshallerHookFunc() + actual, err := DecodeHookExec(f, tc.f, tc.t) + if tc.err != (err != nil) { + t.Fatalf("case %d: expected err %#v", i, tc.err) + } + if !reflect.DeepEqual(actual, tc.result) { + t.Fatalf( + "case %d: expected %#v, got %#v", + i, tc.result, actual) + } + } +}