From 7c0d205b62776a58b80f5714687122bdfcb6df30 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Fri, 18 Jun 2021 14:49:17 +0200 Subject: [PATCH 1/2] Flush buffered responses if we have to wait for the next request Don't wait for the next request as this can take some time, instead flush the outstanding responses already. Fixes #1043 --- header.go | 9 +++++++- server.go | 24 ++++++++++++++++++-- server_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 3 deletions(-) diff --git a/header.go b/header.go index 34e2de94e1..677219ae93 100644 --- a/header.go +++ b/header.go @@ -1494,13 +1494,20 @@ func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool) // // io.EOF is returned if r is closed before reading the first header byte. func (h *RequestHeader) Read(r *bufio.Reader) error { + return h.readLoop(r, true) +} + +// readLoop reads request header from r optionally loops until it has enough data. +// +// io.EOF is returned if r is closed before reading the first header byte. +func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error { n := 1 for { err := h.tryRead(r, n) if err == nil { return nil } - if err != errNeedMore { + if !waitForMore || err != errNeedMore { h.resetSkipNormalize() return err } diff --git a/server.go b/server.go index d47a756618..2903510054 100644 --- a/server.go +++ b/server.go @@ -2091,8 +2091,28 @@ func (s *Server) serveConn(c net.Conn) (err error) { ctx.Request.Header.DisableNormalizing() ctx.Response.Header.DisableNormalizing() } - // reading Headers - if err = ctx.Request.Header.Read(br); err == nil { + + // Reading Headers. + // + // If we have pipline response in the outgoing buffer, + // we only want to try and read the next headers once. + // If we have to wait for the next request we flush the + // outgoing buffer first so it doesn't have to wait. + if bw != nil && bw.Buffered() > 0 { + err = ctx.Request.Header.readLoop(br, false) + if err == errNeedMore { + err = bw.Flush() + if err != nil { + break + } + + err = ctx.Request.Header.Read(br) + } + } else { + err = ctx.Request.Header.Read(br) + } + + if err == nil { if onHdrRecv := s.HeaderReceived; onHdrRecv != nil { reqConf := onHdrRecv(&ctx.Request.Header) if reqConf.ReadTimeout > 0 { diff --git a/server_test.go b/server_test.go index b258de58a8..029bbb6299 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,67 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +func TestServerPipelineFlush(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatal(err) + } + + // Write a partial request. + if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: ")); err != nil { + t.Fatal(err) + } + go func() { + // Wait for 100ms to finish the request + time.Sleep(time.Millisecond * 100) + + if _, err = c.Write([]byte("google.com\r\n\r\n")); err != nil { + t.Error(err) + } + }() + + start := time.Now() + br := bufio.NewReader(c) + var resp Response + + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + // Since the second request takes 100ms to finish we expect the first one to be flushed earlier. + d := time.Since(start) + if d > time.Millisecond*10 { + t.Fatalf("had to wait for %v", d) + } + + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } +} + func TestServerInvalidHeader(t *testing.T) { t.Parallel() From 643b0aab51a656dd419086388be9e54701e1d69c Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Fri, 25 Jun 2021 10:13:06 +0200 Subject: [PATCH 2/2] Only peek 1 byte Make sure old clients that send bogus \r\n still work. See: https://github.com/golang/go/commit/bf5e19fbaf02b1b25fbe50c27ec301fe830a28d0 --- server.go | 2 +- server_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/server.go b/server.go index 2903510054..a34ff18c2c 100644 --- a/server.go +++ b/server.go @@ -2056,7 +2056,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { // within the idle time. if connRequestNum > 1 { var b []byte - b, err = br.Peek(4) + b, err = br.Peek(1) if len(b) == 0 { // If reading from a keep-alive connection returns nothing it means // the connection was closed (either timeout or from the other side). diff --git a/server_test.go b/server_test.go index 029bbb6299..b2f1d7e741 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,94 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +func TestServerCRNLAfterPost_Pipeline(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + Logger: &testLogger{}, + } + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it + "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } +} + +func TestServerCRNLAfterPost(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + Logger: &testLogger{}, + ReadTimeout: time.Millisecond * 1, + } + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n", // <-- this stuff is bogus, but we'll ignore it + )); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if err := resp.Read(br); err == nil { + t.Fatal("expected error") // We didn't send a request so we should get an error here. + } +} + func TestServerPipelineFlush(t *testing.T) { t.Parallel()