From 4ed933a2e770d457edba006bee30586c93d055df Mon Sep 17 00:00:00 2001 From: Meng Date: Fri, 18 Jun 2021 19:43:29 +0800 Subject: [PATCH] fix: set content-length properly when StreanRequestBody was enabled (#1049) * fix: set content-length properly when StreanRequestBody was enabled * fix: add test cases for validating content length of streaming request --- http.go | 2 +- server_test.go | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/http.go b/http.go index 7b3c51c4cf..2511819b00 100644 --- a/http.go +++ b/http.go @@ -1185,7 +1185,7 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre req.body = bodyBuf req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength) - req.Header.SetContentLength(len(bodyBuf.B)) + req.Header.SetContentLength(contentLength) return nil } diff --git a/server_test.go b/server_test.go index b8cf7eebd4..c7b8f485c1 100644 --- a/server_test.go +++ b/server_test.go @@ -1108,7 +1108,6 @@ Host: asbd Connection: close ` - ln := fasthttputil.NewInmemoryListener() s := &Server{ @@ -3578,6 +3577,47 @@ func TestStreamRequestBodyExceedMaxSize(t *testing.T) { } } +func TestStreamBodyReqestContentLength(t *testing.T) { + t.Parallel() + content := strings.Repeat("1", 1<<15) // 32K + contentLength := len(content) + + s := &Server{ + Handler: func(ctx *RequestCtx) { + realContentLength := ctx.Request.Header.ContentLength() + if realContentLength != contentLength { + t.Fatal("incorrect content length") + } + }, + MaxRequestBodySize: 1 * 1024 * 1024, // 1M + StreamRequestBody: true, + } + + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, content))); err != nil { + t.Fatal(err) + } + + ch := make(chan error) + go func() { + ch <- s.ServeConn(sc) + }() + + if err := sc.Close(); err != nil { + t.Fatal(err) + } + + select { + case err := <-ch: + if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match. + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(time.Second): + t.Fatal("test timeout") + } +} + func checkReader(t *testing.T, r io.Reader, expected string) { b := make([]byte, len(expected)) if _, err := io.ReadFull(r, b); err != nil {