Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support custom formvalue function #1453

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 54 additions & 14 deletions server.go
Expand Up @@ -406,6 +406,12 @@ type Server struct {
// instead.
TLSConfig *tls.Config

// FormValueFunc, which is used by RequestCtx.FormValue and support for customising
// the behaviour of the RequestCtx.FormValue function.
//
// NetHttpFormValueFunc gives a FormValueFunc func implementation that is consistent with net/http.
FormValueFunc FormValueFunc

nextProtos map[string]ServeHandler

concurrency uint32
Expand Down Expand Up @@ -604,6 +610,7 @@ type RequestCtx struct {

hijackHandler HijackHandler
hijackNoResponse bool
formValueFunc FormValueFunc
}

// HijackHandler must process the hijacked connection c.
Expand Down Expand Up @@ -1108,23 +1115,54 @@ func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) {
//
// The returned value is valid until your request handler returns.
func (ctx *RequestCtx) FormValue(key string) []byte {
v := ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
if ctx.formValueFunc != nil {
return ctx.formValueFunc(ctx, key)
}
v = ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
return defaultFormValue(ctx, key)
}

type FormValueFunc func(*RequestCtx, string) []byte

var (
defaultFormValue = func(ctx *RequestCtx, key string) []byte {
v := ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
v = ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
return nil
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])

// NetHttpFormValueFunc gives consistent behavior with net/http. POST and PUT body parameters take precedence over URL query string values.
NetHttpFormValueFunc = func(ctx *RequestCtx, key string) []byte {
v := ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
v = ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
return nil
}
return nil
}
)

// IsGet returns true if request method is GET.
func (ctx *RequestCtx) IsGet() bool {
Expand Down Expand Up @@ -2638,7 +2676,9 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
} else {
ctx = v.(*RequestCtx)
}

if s.FormValueFunc != nil {
ctx.formValueFunc = s.FormValueFunc
}
ctx.c = c

return ctx
Expand Down
36 changes: 36 additions & 0 deletions server_test.go
Expand Up @@ -1713,6 +1713,20 @@ func TestRequestCtxFormValue(t *testing.T) {
}
}

func TestSetStandardFormValueFunc(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.SetRequestURI("/foo/bar?aaa=bbb")
req.SetBodyString("aaa=port")
req.Header.SetContentType("application/x-www-form-urlencoded")
ctx.Init(&req, nil, nil)
ctx.formValueFunc = NetHttpFormValueFunc
v := ctx.FormValue("aaa")
if string(v) != "port" {
t.Fatalf("unexpected value %q. Expecting %q", v, "port")
}
}
func TestRequestCtxUserValue(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -3287,6 +3301,28 @@ func TestServeConnSingleRequest(t *testing.T) {
verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
}

func TestServerSetFormValueFunc(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa", ctx.FormValue("aaa"))
},
FormValueFunc: func(ctx *RequestCtx, s string) []byte {
return []byte(s)
},
}

rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")

if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}

br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "aaa")
}

func TestServeConnMultiRequests(t *testing.T) {
t.Parallel()

Expand Down