diff --git a/header.go b/header.go index 34e2de94e1..677219ae93 100644 --- a/header.go +++ b/header.go @@ -1494,13 +1494,20 @@ func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool) // // io.EOF is returned if r is closed before reading the first header byte. func (h *RequestHeader) Read(r *bufio.Reader) error { + return h.readLoop(r, true) +} + +// readLoop reads request header from r optionally loops until it has enough data. +// +// io.EOF is returned if r is closed before reading the first header byte. +func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error { n := 1 for { err := h.tryRead(r, n) if err == nil { return nil } - if err != errNeedMore { + if !waitForMore || err != errNeedMore { h.resetSkipNormalize() return err } diff --git a/server.go b/server.go index d47a756618..2903510054 100644 --- a/server.go +++ b/server.go @@ -2091,8 +2091,28 @@ func (s *Server) serveConn(c net.Conn) (err error) { ctx.Request.Header.DisableNormalizing() ctx.Response.Header.DisableNormalizing() } - // reading Headers - if err = ctx.Request.Header.Read(br); err == nil { + + // Reading Headers. + // + // If we have pipline response in the outgoing buffer, + // we only want to try and read the next headers once. + // If we have to wait for the next request we flush the + // outgoing buffer first so it doesn't have to wait. + if bw != nil && bw.Buffered() > 0 { + err = ctx.Request.Header.readLoop(br, false) + if err == errNeedMore { + err = bw.Flush() + if err != nil { + break + } + + err = ctx.Request.Header.Read(br) + } + } else { + err = ctx.Request.Header.Read(br) + } + + if err == nil { if onHdrRecv := s.HeaderReceived; onHdrRecv != nil { reqConf := onHdrRecv(&ctx.Request.Header) if reqConf.ReadTimeout > 0 { diff --git a/server_test.go b/server_test.go index c7b8f485c1..b7ad3ffe41 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,67 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +func TestServerPipelineFlush(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatal(err) + } + + // Write a partial request. + if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: ")); err != nil { + t.Fatal(err) + } + go func() { + // Wait for 100ms to finish the request + time.Sleep(time.Millisecond * 100) + + if _, err = c.Write([]byte("google.com\r\n\r\n")); err != nil { + t.Error(err) + } + }() + + start := time.Now() + br := bufio.NewReader(c) + var resp Response + + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + // Since the second request takes 100ms to finish we expect the first one to be flushed earlier. + d := time.Since(start) + if d > time.Millisecond*10 { + t.Fatalf("had to wait for %v", d) + } + + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } +} + func TestServerInvalidHeader(t *testing.T) { t.Parallel()