From 1a5f2f40c6a75d42c5c5935efc73f79644cd5b73 Mon Sep 17 00:00:00 2001 From: ArminBTVS <98586621+ArminBTVS@users.noreply.github.com> Date: Mon, 14 Mar 2022 10:53:16 +0100 Subject: [PATCH] Read response when client closes connection #1232 (#1233) * Read response when client closes connection #1232 * Fix edge case were client responds with invalid header * Follow linter suggestions for tests * Changes after review * Reafactor error check after review * Handle connection reset on windows * Remove format string from test where not needed * Run connection reset tests not on Windows --- client.go | 16 +++--- client_test.go | 2 +- client_unix_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++++ http.go | 6 ++ tcp.go | 13 +++++ tcp_windows.go | 13 +++++ 6 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 client_unix_test.go create mode 100644 tcp.go create mode 100644 tcp_windows.go diff --git a/client.go b/client.go index b36ca40824..2eaab39d4e 100644 --- a/client.go +++ b/client.go @@ -1437,12 +1437,12 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) if err == nil { err = bw.Flush() } - if err != nil { - c.releaseWriter(bw) + c.releaseWriter(bw) + isConnRST := isConnectionReset(err) + if err != nil && !isConnRST { c.closeConn(cc) return true, err } - c.releaseWriter(bw) if c.ReadTimeout > 0 { // Set Deadline every time, since golang has fixed the performance issue @@ -1462,22 +1462,22 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) } br := c.acquireReader(conn) - if err = resp.ReadLimitBody(br, c.MaxResponseBodySize); err != nil { - c.releaseReader(br) + err = resp.ReadLimitBody(br, c.MaxResponseBodySize) + c.releaseReader(br) + if err != nil { c.closeConn(cc) // Don't retry in case of ErrBodyTooLarge since we will just get the same again. retry := err != ErrBodyTooLarge return retry, err } - c.releaseReader(br) - if resetConnection || req.ConnectionClose() || resp.ConnectionClose() { + if resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST { c.closeConn(cc) } else { c.releaseConn(cc) } - return false, err + return false, nil } var ( diff --git a/client_test.go b/client_test.go index 78eda2e58e..30c8be7be2 100644 --- a/client_test.go +++ b/client_test.go @@ -2837,6 +2837,6 @@ func TestHttpsRequestWithoutParsedURL(t *testing.T) { _, err := client.doNonNilReqResp(req, &Response{}) if err != nil { - t.Fatalf("https requests with IsTLS client must succeed") + t.Fatal("https requests with IsTLS client must succeed") } } diff --git a/client_unix_test.go b/client_unix_test.go new file mode 100644 index 0000000000..f369111d5c --- /dev/null +++ b/client_unix_test.go @@ -0,0 +1,136 @@ +//go:build !windows +// +build !windows + +package fasthttp + +import ( + "io" + "io/ioutil" + "net" + "net/http" + "strings" + "testing" +) + +// See issue #1232 +func TestRstConnResponseWhileSending(t *testing.T) { + const expectedStatus = http.StatusTeapot + const payload = "payload" + + srv, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + go func() { + for { + conn, err := srv.Accept() + if err != nil { + return + } + + // Read at least one byte of the header + // Otherwise we would have an unsolicited response + _, err = ioutil.ReadAll(io.LimitReader(conn, 1)) + if err != nil { + t.Error(err) + } + + // Respond + _, err = conn.Write([]byte("HTTP/1.1 418 Teapot\r\n\r\n")) + if err != nil { + t.Error(err) + } + + // Forcefully close connection + err = conn.(*net.TCPConn).SetLinger(0) + if err != nil { + t.Error(err) + } + conn.Close() + } + }() + + svrUrl := "http://" + srv.Addr().String() + client := HostClient{Addr: srv.Addr().String()} + + for i := 0; i < 100; i++ { + req := AcquireRequest() + defer ReleaseRequest(req) + resp := AcquireResponse() + defer ReleaseResponse(resp) + + req.Header.SetMethod("POST") + req.SetBodyStream(strings.NewReader(payload), len(payload)) + req.SetRequestURI(svrUrl) + + err = client.Do(req, resp) + if err != nil { + t.Fatal(err) + } + if expectedStatus != resp.StatusCode() { + t.Fatalf("Expected %d status code, but got %d", expectedStatus, resp.StatusCode()) + } + } +} + +// See issue #1232 +func TestRstConnClosedWithoutResponse(t *testing.T) { + const payload = "payload" + + srv, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + go func() { + for { + conn, err := srv.Accept() + if err != nil { + return + } + + // Read at least one byte of the header + // Otherwise we would have an unsolicited response + _, err = ioutil.ReadAll(io.LimitReader(conn, 1)) + if err != nil { + t.Error(err) + } + + // Respond with incomplete header + _, err = conn.Write([]byte("Http")) + if err != nil { + t.Error(err) + } + + // Forcefully close connection + err = conn.(*net.TCPConn).SetLinger(0) + if err != nil { + t.Error(err) + } + conn.Close() + } + }() + + svrUrl := "http://" + srv.Addr().String() + client := HostClient{Addr: srv.Addr().String()} + + for i := 0; i < 100; i++ { + req := AcquireRequest() + defer ReleaseRequest(req) + resp := AcquireResponse() + defer ReleaseResponse(resp) + + req.Header.SetMethod("POST") + req.SetBodyStream(strings.NewReader(payload), len(payload)) + req.SetRequestURI(svrUrl) + + err = client.Do(req, resp) + + if !isConnectionReset(err) { + t.Fatal("Expected connection reset error") + } + } +} diff --git a/http.go b/http.go index 47431cdcd8..75b40c9b74 100644 --- a/http.go +++ b/http.go @@ -1291,6 +1291,9 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { if !resp.mustSkipBody() { err = resp.ReadBody(r, maxBodySize) if err != nil { + if isConnectionReset(err) { + return nil + } return err } } @@ -1298,6 +1301,9 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { if resp.Header.ContentLength() == -1 { err = resp.Header.ReadTrailer(r) if err != nil && err != io.EOF { + if isConnectionReset(err) { + return nil + } return err } } diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000000..54d30334ea --- /dev/null +++ b/tcp.go @@ -0,0 +1,13 @@ +//go:build !windows +// +build !windows + +package fasthttp + +import ( + "errors" + "syscall" +) + +func isConnectionReset(err error) bool { + return errors.Is(err, syscall.ECONNRESET) +} diff --git a/tcp_windows.go b/tcp_windows.go new file mode 100644 index 0000000000..5c33025f40 --- /dev/null +++ b/tcp_windows.go @@ -0,0 +1,13 @@ +//go:build windows +// +build windows + +package fasthttp + +import ( + "errors" + "syscall" +) + +func isConnectionReset(err error) bool { + return errors.Is(err, syscall.WSAECONNRESET) +}