diff --git a/server.go b/server.go index 2903510054..a34ff18c2c 100644 --- a/server.go +++ b/server.go @@ -2056,7 +2056,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { // within the idle time. if connRequestNum > 1 { var b []byte - b, err = br.Peek(4) + b, err = br.Peek(1) if len(b) == 0 { // If reading from a keep-alive connection returns nothing it means // the connection was closed (either timeout or from the other side). diff --git a/server_test.go b/server_test.go index 029bbb6299..b2f1d7e741 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,94 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +func TestServerCRNLAfterPost_Pipeline(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + Logger: &testLogger{}, + } + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it + "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } +} + +func TestServerCRNLAfterPost(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + Logger: &testLogger{}, + ReadTimeout: time.Millisecond * 1, + } + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n", // <-- this stuff is bogus, but we'll ignore it + )); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if err := resp.Read(br); err == nil { + t.Fatal("expected error") // We didn't send a request so we should get an error here. + } +} + func TestServerPipelineFlush(t *testing.T) { t.Parallel()