diff --git a/server.go b/server.go index 08d4dc6adc..c6a292305f 100644 --- a/server.go +++ b/server.go @@ -777,12 +777,44 @@ func (ctx *RequestCtx) Conn() net.Conn { return ctx.c } +func (ctx *RequestCtx) reset() { + ctx.userValues.Reset() + ctx.Request.Reset() + ctx.Response.Reset() + ctx.fbr.reset() + + ctx.connID = 0 + ctx.connRequestNum = 0 + ctx.connTime = zeroTime + ctx.remoteAddr = nil + ctx.time = zeroTime + ctx.s = nil + ctx.c = nil + + if ctx.timeoutResponse != nil { + ctx.timeoutResponse.Reset() + } + + if ctx.timeoutTimer != nil { + stopTimer(ctx.timeoutTimer) + } + + ctx.hijackHandler = nil + ctx.hijackNoResponse = false +} + type firstByteReader struct { c net.Conn ch byte byteRead bool } +func (r *firstByteReader) reset() { + r.c = nil + r.ch = 0 + r.byteRead = false +} + func (r *firstByteReader) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil @@ -2084,7 +2116,6 @@ func (s *Server) serveConn(c net.Conn) (err error) { connectionClose bool isHTTP11 bool - reqReset, respReset bool continueReadingRequest bool = true ) for { @@ -2123,7 +2154,6 @@ func (s *Server) serveConn(c net.Conn) (err error) { br, err = acquireByteReader(&ctx) } - reqReset, respReset = false, false ctx.Request.isTLS = isTLS ctx.Response.Header.noDefaultContentType = s.NoDefaultContentType ctx.Response.Header.noDefaultDate = s.NoDefaultDate @@ -2375,13 +2405,9 @@ func (s *Server) serveConn(c net.Conn) (err error) { if hijackHandler != nil { var hjr io.Reader = c - hctx := ctx if br != nil { hjr = br br = nil - - // br may point to ctx.fbr, so do not return ctx into pool below. - ctx = nil } if bw != nil { err = bw.Flush() @@ -2395,7 +2421,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { if err != nil { break } - go hijackConnHandler(hctx, hjr, c, s, hijackHandler) + go hijackConnHandler(ctx, hjr, c, s, hijackHandler) err = errHijacked break } @@ -2408,8 +2434,6 @@ func (s *Server) serveConn(c net.Conn) (err error) { s.setState(c, StateIdle) ctx.userValues.Reset() - - reqReset, respReset = true, true ctx.Request.Reset() ctx.Response.Reset() @@ -2425,18 +2449,10 @@ func (s *Server) serveConn(c net.Conn) (err error) { if bw != nil { releaseWriter(s, bw) } - if ctx != nil { - // in unexpected cases the for loop will break - // before request/response reset call. in such cases, call it before - // release to fix #548 - if !reqReset { - ctx.Request.Reset() - } - if !respReset { - ctx.Response.Reset() - } + if hijackHandler == nil { s.releaseCtx(ctx) } + return } @@ -2458,7 +2474,7 @@ func hijackConnHandler(ctx *RequestCtx, r io.Reader, c net.Conn, s *Server, h Hi c.Close() s.releaseHijackConn(hjc) } - ctx.ResetUserValues() + s.releaseCtx(ctx) } func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn { @@ -2603,17 +2619,19 @@ func releaseWriter(s *Server, w *bufio.Writer) { func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) { v := s.ctxPool.Get() if v == nil { - ctx = &RequestCtx{ - s: s, - } keepBodyBuffer := !s.ReduceMemoryUsage + + ctx = new(RequestCtx) ctx.Request.keepBodyBuffer = keepBodyBuffer ctx.Response.keepBodyBuffer = keepBodyBuffer } else { ctx = v.(*RequestCtx) } + + ctx.s = s ctx.c = c - return + + return ctx } // Init2 prepares ctx for passing to RequestHandler. @@ -2736,10 +2754,8 @@ func (s *Server) releaseCtx(ctx *RequestCtx) { if ctx.timeoutResponse != nil { panic("BUG: cannot release timed out RequestCtx") } - ctx.c = nil - ctx.remoteAddr = nil - ctx.fbr.c = nil - ctx.userValues.Reset() + + ctx.reset() s.ctxPool.Put(ctx) } diff --git a/server_test.go b/server_test.go index 47eb84d37b..cad3cf1a18 100644 --- a/server_test.go +++ b/server_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -2088,6 +2089,63 @@ func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) { } } +// TestServerErrorHandler tests unexpected cases the for loop will break +// before request/response reset call. in such cases, call it before +// release to fix #548. +func TestServerErrorHandler(t *testing.T) { + t.Parallel() + + var resultReqStr, resultRespStr string + + s := &Server{ + Handler: func(ctx *RequestCtx) {}, + ErrorHandler: func(ctx *RequestCtx, err error) { + resultReqStr = ctx.Request.String() + resultRespStr = ctx.Response.String() + }, + MaxRequestBodySize: 10, + } + + reqStrTpl := "POST %s HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/octet-stream\r\nContent-Length: %d\r\nConnection: keep-alive\r\n\r\n" + respRegex := regexp.MustCompile("HTTP/1.1 200 OK\r\nDate: (.*)\r\nContent-Length: 0\r\n\r\n") + + rw := &readWriter{} + + for i := 0; i < 100; i++ { + body := strings.Repeat("@", s.MaxRequestBodySize+1) + path := fmt.Sprintf("/%d", i) + + reqStr := fmt.Sprintf(reqStrTpl, path, len(body)) + expectedReqStr := fmt.Sprintf(reqStrTpl, path, 0) + + rw.r.WriteString(reqStr) + rw.r.WriteString(body) + + ch := make(chan struct{}) + go func() { + err := s.ServeConn(rw) + if err != nil && !errors.Is(err, ErrBodyTooLarge) { + t.Errorf("unexpected error in ServeConn: %s", err) + } + close(ch) + }() + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if resultReqStr != expectedReqStr { + t.Errorf("[iter: %d] Request == %s, want %s", i, resultReqStr, reqStr) + } + + if !respRegex.MatchString(resultRespStr) { + t.Errorf("[iter: %d] Response == %s, want regex %s", i, resultRespStr, respRegex) + } + } +} + func TestServeConnHijackResetUserValues(t *testing.T) { t.Parallel() @@ -2369,71 +2427,131 @@ func TestRequestCtxSendFile(t *testing.T) { } } -func TestRequestCtxHijack(t *testing.T) { - t.Parallel() +func testRequestCtxHijack(t *testing.T, s *Server) { + t.Helper() - hijackStartCh := make(chan struct{}) - hijackStopCh := make(chan struct{}) - s := &Server{ - Handler: func(ctx *RequestCtx) { - if ctx.Hijacked() { - t.Error("connection mustn't be hijacked") - } - ctx.Hijack(func(c net.Conn) { - <-hijackStartCh + type hijackSignal struct { + id int + rw *readWriter + } - b := make([]byte, 1) - // ping-pong echo via hijacked conn - for { - n, err := c.Read(b) - if n != 1 { - if err == io.EOF { - close(hijackStopCh) - return - } - if err != nil { - t.Errorf("unexpected error: %s", err) - } - t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) - } - if _, err = c.Write(b); err != nil { - t.Errorf("unexpected error when writing data: %s", err) + wg := sync.WaitGroup{} + totalConns := 100 + hijackStartCh := make(chan *hijackSignal, totalConns) + hijackStopCh := make(chan *hijackSignal, totalConns) + + s.Handler = func(ctx *RequestCtx) { + if ctx.Hijacked() { + t.Error("connection mustn't be hijacked") + } + + ctx.Hijack(func(c net.Conn) { + signal := <-hijackStartCh + defer func() { + hijackStopCh <- signal + wg.Done() + }() + + b := make([]byte, 1) + stop := false + + // ping-pong echo via hijacked conn + for !stop { + n, err := c.Read(b) + if err != nil { + if errors.Is(err, io.EOF) { + stop = true + + continue } + + t.Errorf("unexpected read error: %s", err) + } else if n != 1 { + t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) + } + + if _, err = c.Write(b); err != nil { + t.Errorf("unexpected error when writing data: %s", err) } - }) - if !ctx.Hijacked() { - t.Error("connection must be hijacked") } - ctx.Success("foo/bar", []byte("hijack it!")) - }, + }) + + if !ctx.Hijacked() { + t.Error("connection must be hijacked") + } + + ctx.Success("foo/bar", []byte("hijack it!")) } hijackedString := "foobar baz hijacked!!!" - rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString(hijackedString) - if err := s.ServeConn(rw); err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } + for i := 0; i < totalConns; i++ { + wg.Add(1) - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!") + go func(t *testing.T, id int) { + t.Helper() - close(hijackStartCh) - select { - case <-hijackStopCh: - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout") - } + rw := new(readWriter) + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString(hijackedString) - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + if err := s.ServeConn(rw); err != nil { + t.Errorf("[iter: %d] Unexpected error from serveConn: %s", id, err) + } + + hijackStartCh <- &hijackSignal{id, rw} + }(t, i) } - if string(data) != hijackedString { - t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString) + + wg.Wait() + + count := 0 + for count != totalConns { + select { + case signal := <-hijackStopCh: + count++ + + id := signal.id + rw := signal.rw + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!") + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Errorf("[iter: %d] Unexpected error when reading remaining data: %s", id, err) + + return + } + if string(data) != hijackedString { + t.Errorf( + "[iter: %d] Unexpected response %s. Expecting %s", + id, data, hijackedString, + ) + + return + } + case <-time.After(200 * time.Millisecond): + t.Errorf("timeout") + } } + + close(hijackStartCh) + close(hijackStopCh) +} + +func TestRequestCtxHijack(t *testing.T) { + t.Parallel() + + testRequestCtxHijack(t, &Server{}) +} + +func TestRequestCtxHijackReduceMemoryUsage(t *testing.T) { + t.Parallel() + + testRequestCtxHijack(t, &Server{ + ReduceMemoryUsage: true, + }) } func TestRequestCtxHijackNoResponse(t *testing.T) {