From e8e52d2a0afd3a24f986c0b09760bbb6ef11a6d3 Mon Sep 17 00:00:00 2001 From: kinggo Date: Thu, 8 Dec 2022 20:18:02 +0800 Subject: [PATCH] feat: support custom formvalue function --- server.go | 68 +++++++++++++++++++++++++++++++++++++++----------- server_test.go | 36 ++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 14 deletions(-) diff --git a/server.go b/server.go index 0be703c77d..a513c5b7be 100644 --- a/server.go +++ b/server.go @@ -406,6 +406,12 @@ type Server struct { // instead. TLSConfig *tls.Config + // FormValueFunc, which is used by RequestCtx.FormValue and support for customising + // the behaviour of the RequestCtx.FormValue function. + // + // NetHttpFormValueFunc gives a FormValueFunc func implementation that is consistent with net/http. + FormValueFunc FormValueFunc + nextProtos map[string]ServeHandler concurrency uint32 @@ -604,6 +610,7 @@ type RequestCtx struct { hijackHandler HijackHandler hijackNoResponse bool + formValueFunc FormValueFunc } // HijackHandler must process the hijacked connection c. @@ -1108,23 +1115,54 @@ func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) { // // The returned value is valid until your request handler returns. func (ctx *RequestCtx) FormValue(key string) []byte { - v := ctx.QueryArgs().Peek(key) - if len(v) > 0 { - return v + if ctx.formValueFunc != nil { + return ctx.formValueFunc(ctx, key) } - v = ctx.PostArgs().Peek(key) - if len(v) > 0 { - return v + return defaultFormValue(ctx, key) +} + +type FormValueFunc func(*RequestCtx, string) []byte + +var ( + defaultFormValue = func(ctx *RequestCtx, key string) []byte { + v := ctx.QueryArgs().Peek(key) + if len(v) > 0 { + return v + } + v = ctx.PostArgs().Peek(key) + if len(v) > 0 { + return v + } + mf, err := ctx.MultipartForm() + if err == nil && mf.Value != nil { + vv := mf.Value[key] + if len(vv) > 0 { + return []byte(vv[0]) + } + } + return nil } - mf, err := ctx.MultipartForm() - if err == nil && mf.Value != nil { - vv := mf.Value[key] - if len(vv) > 0 { - return []byte(vv[0]) + + // NetHttpFormValueFunc gives consistent behavior with net/http. POST and PUT body parameters take precedence over URL query string values. + NetHttpFormValueFunc = func(ctx *RequestCtx, key string) []byte { + v := ctx.PostArgs().Peek(key) + if len(v) > 0 { + return v + } + mf, err := ctx.MultipartForm() + if err == nil && mf.Value != nil { + vv := mf.Value[key] + if len(vv) > 0 { + return []byte(vv[0]) + } + } + v = ctx.QueryArgs().Peek(key) + if len(v) > 0 { + return v } + return nil } - return nil -} +) // IsGet returns true if request method is GET. func (ctx *RequestCtx) IsGet() bool { @@ -2638,7 +2676,9 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) { } else { ctx = v.(*RequestCtx) } - + if s.FormValueFunc != nil { + ctx.formValueFunc = s.FormValueFunc + } ctx.c = c return ctx diff --git a/server_test.go b/server_test.go index 0c44a23f12..7d96e93cf3 100644 --- a/server_test.go +++ b/server_test.go @@ -1713,6 +1713,20 @@ func TestRequestCtxFormValue(t *testing.T) { } } +func TestSetStandardFormValueFunc(t *testing.T) { + t.Parallel() + var ctx RequestCtx + var req Request + req.SetRequestURI("/foo/bar?aaa=bbb") + req.SetBodyString("aaa=port") + req.Header.SetContentType("application/x-www-form-urlencoded") + ctx.Init(&req, nil, nil) + ctx.formValueFunc = NetHttpFormValueFunc + v := ctx.FormValue("aaa") + if string(v) != "port" { + t.Fatalf("unexpected value %q. Expecting %q", v, "port") + } +} func TestRequestCtxUserValue(t *testing.T) { t.Parallel() @@ -3287,6 +3301,28 @@ func TestServeConnSingleRequest(t *testing.T) { verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") } +func TestServerSetFormValueFunc(t *testing.T) { + t.Parallel() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Success("aaa", ctx.FormValue("aaa")) + }, + FormValueFunc: func(ctx *RequestCtx, s string) []byte { + return []byte(s) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %v", err) + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, "aaa", "aaa") +} + func TestServeConnMultiRequests(t *testing.T) { t.Parallel()