From de9733b1603098f6d5029df6712f7794f3fde0da Mon Sep 17 00:00:00 2001 From: Timur Mazitov <65964304+mazitovt@users.noreply.github.com> Date: Wed, 29 Jun 2022 20:15:12 +0500 Subject: [PATCH] fix: support encoding.TextMarshaler in StyleParamWithLocation (#634) Co-authored-by: Marcin Romaszewicz <47459980+deepmap-marcinr@users.noreply.github.com> --- pkg/runtime/styleparam.go | 20 ++++++++++++++++++++ pkg/runtime/styleparam_test.go | 15 +++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/pkg/runtime/styleparam.go b/pkg/runtime/styleparam.go index 6c01def1b..6de67cf8a 100644 --- a/pkg/runtime/styleparam.go +++ b/pkg/runtime/styleparam.go @@ -14,6 +14,7 @@ package runtime import ( + "encoding" "errors" "fmt" "net/url" @@ -62,6 +63,25 @@ func StyleParamWithLocation(style string, explode bool, paramName string, paramL t = v.Type() } + // If the value implements encoding.TextMarshaler we use it for marshaling + // https://github.com/deepmap/oapi-codegen/issues/504 + if tu, ok := value.(encoding.TextMarshaler); ok { + t := reflect.Indirect(reflect.ValueOf(value)).Type() + convertableToTime := t.ConvertibleTo(reflect.TypeOf(time.Time{})) + convertableToDate := t.ConvertibleTo(reflect.TypeOf(types.Date{})) + + // Since both time.Time and types.Date implement encoding.TextMarshaler + // we should avoid calling theirs MarshalText() + if !convertableToTime && !convertableToDate { + b, err := tu.MarshalText() + if err != nil { + return "", fmt.Errorf("error marshaling '%s' as text: %s", value, err) + } + + return stylePrimitive(style, explode, paramName, paramLocation, string(b)) + } + } + switch t.Kind() { case reflect.Slice: n := v.Len() diff --git a/pkg/runtime/styleparam_test.go b/pkg/runtime/styleparam_test.go index ff7e52d7c..72051a78d 100644 --- a/pkg/runtime/styleparam_test.go +++ b/pkg/runtime/styleparam_test.go @@ -14,6 +14,7 @@ package runtime import ( + "github.com/google/uuid" "testing" "time" @@ -676,4 +677,18 @@ func TestStyleParam(t *testing.T) { result, err = StyleParamWithLocation("simple", false, "id", ParamLocationQuery, object3) assert.NoError(t, err) assert.EqualValues(t, "date_field,1996-03-19,time_field,1996-03-19T00%3A00%3A00Z,uuid_field,baa07328-452e-40bd-aa2e-fa823ec13605", result) + + // Test handling of struct that implement encoding.TextMarshaler + timeVal = time.Date(1996, time.March, 19, 0, 0, 0, 0, time.UTC) + + result, err = StyleParamWithLocation("simple", false, "id", ParamLocationQuery, timeVal) + assert.NoError(t, err) + assert.EqualValues(t, "1996-03-19T00%3A00%3A00Z", result) + + uuidD := uuid.MustParse("972beb41-e5ea-4b31-a79a-96f4999d8769") + + result, err = StyleParamWithLocation("simple", false, "id", ParamLocationQuery, uuidD) + assert.NoError(t, err) + assert.EqualValues(t, "972beb41-e5ea-4b31-a79a-96f4999d8769", result) + }