Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flush buffered responses if we have to wait for the next request #1050

Merged
merged 2 commits into from Jun 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion header.go
Expand Up @@ -1494,13 +1494,20 @@ func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool)
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *RequestHeader) Read(r *bufio.Reader) error {
return h.readLoop(r, true)
}

// readLoop reads request header from r optionally loops until it has enough data.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error {
n := 1
for {
err := h.tryRead(r, n)
if err == nil {
return nil
}
if err != errNeedMore {
if !waitForMore || err != errNeedMore {
h.resetSkipNormalize()
return err
}
Expand Down
26 changes: 23 additions & 3 deletions server.go
Expand Up @@ -2056,7 +2056,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
// within the idle time.
if connRequestNum > 1 {
var b []byte
b, err = br.Peek(4)
b, err = br.Peek(1)
if len(b) == 0 {
// If reading from a keep-alive connection returns nothing it means
// the connection was closed (either timeout or from the other side).
Expand Down Expand Up @@ -2091,8 +2091,28 @@ func (s *Server) serveConn(c net.Conn) (err error) {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}
// reading Headers
if err = ctx.Request.Header.Read(br); err == nil {

// Reading Headers.
//
// If we have pipline response in the outgoing buffer,
// we only want to try and read the next headers once.
// If we have to wait for the next request we flush the
// outgoing buffer first so it doesn't have to wait.
if bw != nil && bw.Buffered() > 0 {
err = ctx.Request.Header.readLoop(br, false)
if err == errNeedMore {
err = bw.Flush()
if err != nil {
break
}

err = ctx.Request.Header.Read(br)
}
} else {
err = ctx.Request.Header.Read(br)
}

if err == nil {
if onHdrRecv := s.HeaderReceived; onHdrRecv != nil {
reqConf := onHdrRecv(&ctx.Request.Header)
if reqConf.ReadTimeout > 0 {
Expand Down
149 changes: 149 additions & 0 deletions server_test.go
Expand Up @@ -23,6 +23,155 @@ import (
// Make sure RequestCtx implements context.Context
var _ context.Context = &RequestCtx{}

func TestServerCRNLAfterPost_Pipeline(t *testing.T) {
t.Parallel()

s := &Server{
Handler: func(ctx *RequestCtx) {
},
Logger: &testLogger{},
}

ln := fasthttputil.NewInmemoryListener()
defer ln.Close()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()

c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
"\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it
"GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")); err != nil {
t.Fatal(err)
}

br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}

func TestServerCRNLAfterPost(t *testing.T) {
t.Parallel()

s := &Server{
Handler: func(ctx *RequestCtx) {
},
Logger: &testLogger{},
ReadTimeout: time.Millisecond * 1,
}

ln := fasthttputil.NewInmemoryListener()
defer ln.Close()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()

c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
"\r\n\r\n", // <-- this stuff is bogus, but we'll ignore it
)); err != nil {
t.Fatal(err)
}

br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err == nil {
t.Fatal("expected error") // We didn't send a request so we should get an error here.
}
}

func TestServerPipelineFlush(t *testing.T) {
t.Parallel()

s := &Server{
Handler: func(ctx *RequestCtx) {
},
}
ln := fasthttputil.NewInmemoryListener()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %s", err)
}
}()

c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatal(err)
}

// Write a partial request.
if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: ")); err != nil {
t.Fatal(err)
}
go func() {
// Wait for 100ms to finish the request
time.Sleep(time.Millisecond * 100)

if _, err = c.Write([]byte("google.com\r\n\r\n")); err != nil {
t.Error(err)
}
}()

start := time.Now()
br := bufio.NewReader(c)
var resp Response

if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}

// Since the second request takes 100ms to finish we expect the first one to be flushed earlier.
d := time.Since(start)
if d > time.Millisecond*10 {
t.Fatalf("had to wait for %v", d)
}

if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}

func TestServerInvalidHeader(t *testing.T) {
t.Parallel()

Expand Down