diff --git a/decode_hooks.go b/decode_hooks.go index 4d4bbc7..3a754ca 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -77,6 +77,28 @@ func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc { } } +// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned. +// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages. +func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc { + return func(a, b reflect.Value) (interface{}, error) { + var allErrs string + var out interface{} + var err error + + for _, f := range ff { + out, err = DecodeHookExec(f, a, b) + if err != nil { + allErrs += err.Error() + "\n" + continue + } + + return out, nil + } + + return nil, errors.New(allErrs) + } +} + // StringToSliceHookFunc returns a DecodeHookFunc that converts // string to []string by splitting on the given sep. func StringToSliceHookFunc(sep string) DecodeHookFunc { diff --git a/decode_hooks_test.go b/decode_hooks_test.go index cfa3c18..bf02952 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -84,6 +84,94 @@ func TestComposeDecodeHookFunc_kinds(t *testing.T) { } } +func TestOrComposeDecodeHookFunc(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "foo", nil + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "foo" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_correctValueIsLast(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f3 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return data.(string) + "bar", nil + } + + f := OrComposeDecodeHookFunc(f1, f2, f3) + + result, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err != nil { + t.Fatalf("bad: %s", err) + } + if result.(string) != "bar" { + t.Fatalf("bad: %#v", result) + } +} + +func TestOrComposeDecodeHookFunc_err(t *testing.T) { + f1 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f1 error") + } + + f2 := func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + return nil, errors.New("f2 error") + } + + f := OrComposeDecodeHookFunc(f1, f2) + + _, err := DecodeHookExec( + f, reflect.ValueOf(""), reflect.ValueOf([]byte(""))) + if err == nil { + t.Fatalf("bad: should return an error") + } + if err.Error() != "f1 error\nf2 error\n" { + t.Fatalf("bad: %s", err) + } +} + func TestComposeDecodeHookFunc_safe_nofuncs(t *testing.T) { f := ComposeDecodeHookFunc() type myStruct2 struct {