From 258a4c17b4f451f9a1f9bae3f34d3f5d6e9e0bfd Mon Sep 17 00:00:00 2001 From: Sergio VS Date: Thu, 16 Dec 2021 05:27:02 +0100 Subject: [PATCH] fix: reset response after reset user values on keep-alive connections (#1176) --- server.go | 17 ++++++++++++----- server_test.go | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/server.go b/server.go index 928702cb34..edbd6999fb 100644 --- a/server.go +++ b/server.go @@ -2079,7 +2079,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { connectionClose bool isHTTP11 bool - reqReset bool + reqReset, respReset bool continueReadingRequest bool = true ) for { @@ -2118,7 +2118,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { br, err = acquireByteReader(&ctx) } - reqReset = false + reqReset, respReset = false, false ctx.Request.isTLS = isTLS ctx.Response.Header.noDefaultContentType = s.NoDefaultContentType ctx.Response.Header.noDefaultDate = s.NoDefaultDate @@ -2403,8 +2403,9 @@ func (s *Server) serveConn(c net.Conn) (err error) { s.setState(c, StateIdle) ctx.userValues.Reset() - reqReset = true + reqReset, respReset = true, true ctx.Request.Reset() + ctx.Response.Reset() if atomic.LoadInt32(&s.stop) == 1 { err = nil @@ -2420,11 +2421,14 @@ func (s *Server) serveConn(c net.Conn) (err error) { } if ctx != nil { // in unexpected cases the for loop will break - // before request reset call. in such cases, call it before + // before request/response reset call. in such cases, call it before // release to fix #548 if !reqReset { ctx.Request.Reset() } + if !respReset { + ctx.Response.Reset() + } s.releaseCtx(ctx) } return @@ -2511,7 +2515,7 @@ func writeResponse(ctx *RequestCtx, w *bufio.Writer) error { panic("BUG: cannot write timed out response") } err := ctx.Response.Write(w) - ctx.Response.Reset() + return err } @@ -2796,8 +2800,11 @@ func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverNam if bw == nil { bw = acquireWriter(ctx) } + writeResponse(ctx, bw) //nolint:errcheck + ctx.Response.Reset() bw.Flush() + return bw } diff --git a/server_test.go b/server_test.go index 2f7e6aef72..0a199b8e28 100644 --- a/server_test.go +++ b/server_test.go @@ -14,6 +14,7 @@ import ( "net" "os" "reflect" + "regexp" "strings" "sync" "testing" @@ -2041,23 +2042,28 @@ func TestRequestCtxWriteString(t *testing.T) { } } -func TestServeConnKeepRequestValuesUntilResetUserValues(t *testing.T) { +func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) { t.Parallel() reqStr := "POST /foo HTTP/1.0\r\nHost: google.com\r\nContent-Type: application/octet-stream\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n" + respRegex := regexp.MustCompile("HTTP/1.1 308 Permanent Redirect\r\nServer: fasthttp\r\nDate: (.*)\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n") rw := &readWriter{} rw.r.WriteString(reqStr) - var resultReqStr string + var resultReqStr, resultRespStr string ch := make(chan struct{}) go func() { err := ServeConn(rw, func(ctx *RequestCtx) { + ctx.Response.SetStatusCode(StatusPermanentRedirect) + ctx.SetUserValue("myKey", &closerWithRequestCtx{ ctx: ctx, closeFunc: func(closerCtx *RequestCtx) error { resultReqStr = closerCtx.Request.String() + resultRespStr = closerCtx.Response.String() + return nil }}) }) @@ -2076,6 +2082,10 @@ func TestServeConnKeepRequestValuesUntilResetUserValues(t *testing.T) { if resultReqStr != reqStr { t.Errorf("Request == %s, want %s", resultReqStr, reqStr) } + + if !respRegex.MatchString(resultRespStr) { + t.Errorf("Response == %s, want regex %s", resultRespStr, respRegex) + } } func TestServeConnNonHTTP11KeepAlive(t *testing.T) {