diff --git a/bytesconv.go b/bytesconv.go index b3cf29e3ce..4e2271a712 100644 --- a/bytesconv.go +++ b/bytesconv.go @@ -355,3 +355,23 @@ func appendQuotedPath(dst, src []byte) []byte { } return dst } + +// countHexDigits returns the number of hex digits required to represent n when using writeHexInt +func countHexDigits(n int) int { + if n < 0 { + // developer sanity-check + panic("BUG: int must be positive") + } + + if n == 0 { + return 1 + } + + count := 0 + for n > 0 { + n = n >> 4 + count++ + } + + return count +} diff --git a/http.go b/http.go index 2dd906102f..17bdc2a76d 100644 --- a/http.go +++ b/http.go @@ -120,6 +120,8 @@ type Response struct { raddr net.Addr // Local TCPAddr from concurrently net.Conn laddr net.Addr + + headersWritten bool } // SetHost sets host for the request. @@ -1122,6 +1124,7 @@ func (resp *Response) Reset() { resp.laddr = nil resp.ImmediateHeaderFlush = false resp.StreamBody = false + resp.headersWritten = false } func (resp *Response) resetSkipHeader() { diff --git a/server.go b/server.go index 9cfc17534c..7b2189da9b 100644 --- a/server.go +++ b/server.go @@ -608,6 +608,48 @@ type RequestCtx struct { hijackHandler HijackHandler hijackNoResponse bool formValueFunc FormValueFunc + + disableBuffering bool // disables buffered response body + getUnbufferedWriter func(*RequestCtx) UnbufferedWriter // defines how to get unbuffered writer + unbufferedWriter UnbufferedWriter // writes directly to underlying connection + bytesSent int // number of bytes sent to client using unbuffered operations +} + +// DisableBuffering modifies fasthttp to disable body buffering for this request. +// This is useful for requests that return large data or stream data. +// +// When buffering is disabled you must: +// 1. Set response status and header values before writing body +// 2. Set ContentLength is optional. If not set, the server will use chunked encoding. +// 3. Write body data using methods like ctx.Write or io.Copy(ctx,src), etc. +// 4. Optionally call CloseResponse to finalize the response. +// +// CLosing the response will finalize the response and send the last chunk. +// If the handler does not finish the response, it will be called automatically after handler returns. +// Closing the response will also set BytesSent with the correct number of total bytes sent. +func (ctx *RequestCtx) DisableBuffering() { + ctx.disableBuffering = true + + // We need to create a new unbufferedWriter for each unbuffered request. + // This way we can allow different implementations and be compatible with http2 protocol + if ctx.unbufferedWriter == nil { + if ctx.getUnbufferedWriter != nil { + ctx.unbufferedWriter = ctx.getUnbufferedWriter(ctx) + } else { + ctx.unbufferedWriter = NewUnbufferedWriter(ctx) + } + } +} + +// CloseResponse finalizes non-buffered response dispatch. +// This method must be called after performing non-buffered responses +// If the handler does not finish the response, it will be called automatically +// after the handler function returns. +func (ctx *RequestCtx) CloseResponse() error { + if !ctx.disableBuffering || ctx.unbufferedWriter == nil { + return ErrNotUnbuffered + } + return ctx.unbufferedWriter.Close() } // HijackHandler must process the hijacked connection c. @@ -822,6 +864,11 @@ func (ctx *RequestCtx) reset() { ctx.hijackHandler = nil ctx.hijackNoResponse = false + + ctx.disableBuffering = false + ctx.unbufferedWriter = nil + ctx.getUnbufferedWriter = nil + ctx.bytesSent = 0 } type firstByteReader struct { @@ -1443,10 +1490,28 @@ func (ctx *RequestCtx) NotFound() { // Write writes p into response body. func (ctx *RequestCtx) Write(p []byte) (int, error) { + if ctx.disableBuffering { + return ctx.writeDirect(p) + } + ctx.Response.AppendBody(p) return len(p), nil } +// writeDirect writes p to underlying connection bypassing any buffering. +func (ctx *RequestCtx) writeDirect(p []byte) (int, error) { + if ctx.unbufferedWriter == nil { + ctx.unbufferedWriter = NewUnbufferedWriter(ctx) + } + return ctx.unbufferedWriter.Write(p) +} + +// BytesSent returns the number of bytes sent to the client after non buffered operation. +// Includes headers and body length. +func (ctx *RequestCtx) BytesSent() int { + return ctx.bytesSent +} + // WriteString appends s to response body. func (ctx *RequestCtx) WriteString(s string) (int, error) { ctx.Response.AppendBodyString(s) @@ -2359,6 +2424,11 @@ func (s *Server) serveConn(c net.Conn) (err error) { s.Handler(ctx) } + if ctx.disableBuffering { + _ = ctx.CloseResponse() + break + } + timeoutResponse = ctx.timeoutResponse if timeoutResponse != nil { // Acquire a new ctx because the old one will still be in use by the timeout out handler. diff --git a/server_test.go b/server_test.go index ad188db190..e96233e1ed 100644 --- a/server_test.go +++ b/server_test.go @@ -4237,6 +4237,76 @@ func TestServerChunkedResponse(t *testing.T) { } } +func TestServerDisableBuffering(t *testing.T) { + t.Parallel() + + received := make(chan bool) + done := make(chan bool) + + expectedBody := bytes.Repeat([]byte("a"), 4096) + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.DisableBuffering() + ctx.SetStatusCode(StatusOK) + ctx.SetContentType("text/html; charset=utf-8") + reader := bytes.NewReader(expectedBody) + _, err := io.Copy(ctx, reader) + if err != nil { + t.Fatalf("Unexpected error when copying body: %v", err) + } + ctx.CloseResponse() + if len(ctx.Response.Body()) > 0 { + t.Fatalf("Body was populated when buffer was disabled") + } + + // wait until body is received by the consumer or stop after 2 seconds timeout + select { + case <-received: + case <-time.After(2 * time.Second): + t.Fatal("Body not received by consumer after 2 seconds") + } + + // The consumer received the body, so we can finish the test + done <- true + }, + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, err = conn.Write([]byte("GET /index.html HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %v", err) + } + br := bufio.NewReader(conn) + + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %v", err) + } + if resp.Header.ContentLength() != -1 { + t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), -1) + } + if !bytes.Equal(resp.Body(), expectedBody) { + t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), "foobar") + } + + // Signal that the body was received correctly + received <- true + + // Wait until the server has finished + <-done +} + func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response { var resp Response if err := resp.Read(r); err != nil { diff --git a/unbuffered.go b/unbuffered.go new file mode 100644 index 0000000000..2cf6ea8026 --- /dev/null +++ b/unbuffered.go @@ -0,0 +1,116 @@ +package fasthttp + +import ( + "bufio" + "errors" + "fmt" +) + +type UnbufferedWriter interface { + Write(p []byte) (int, error) + WriteHeaders() (int, error) + Close() error +} + +type UnbufferedWriterHttp1 struct { + writer *bufio.Writer + ctx *RequestCtx + bodyChunkStarted bool + bodyLastChunkSent bool +} + +var ErrNotUnbuffered = errors.New("not unbuffered") +var ErrClosedUnbufferedWriter = errors.New("closed unbuffered writer") + +// Ensure UnbufferedWriterHttp1 implements UnbufferedWriter. +var _ UnbufferedWriter = &UnbufferedWriterHttp1{} + +// NewUnbufferedWriter +// +// Object must be discarded when request is finished +func NewUnbufferedWriter(ctx *RequestCtx) *UnbufferedWriterHttp1 { + writer := acquireWriter(ctx) + return &UnbufferedWriterHttp1{ctx: ctx, writer: writer} +} + +func (uw *UnbufferedWriterHttp1) Write(p []byte) (int, error) { + if uw.writer == nil || uw.ctx == nil { + return 0, ErrClosedUnbufferedWriter + } + + // Write headers if not already sent + if !uw.ctx.Response.headersWritten { + _, err := uw.WriteHeaders() + if err != nil { + return 0, fmt.Errorf("error writing headers: %w", err) + } + } + + // Write body. In chunks if content length is not set. + if uw.ctx.Response.Header.contentLength == -1 && uw.ctx.Response.Header.IsHTTP11() { + uw.bodyChunkStarted = true + err := writeChunk(uw.writer, p) + if err != nil { + return 0, err + } + uw.ctx.bytesSent += len(p) + 4 + countHexDigits(len(p)) + return len(p), nil + } + + n, err := uw.writer.Write(p) + uw.ctx.bytesSent += n + + return n, err +} + +func (uw *UnbufferedWriterHttp1) WriteHeaders() (int, error) { + if uw.writer == nil || uw.ctx == nil { + return 0, ErrClosedUnbufferedWriter + } + + if !uw.ctx.Response.headersWritten { + if uw.ctx.Response.Header.contentLength == 0 && uw.ctx.Response.Header.IsHTTP11() { + if uw.ctx.Response.SkipBody { + uw.ctx.Response.Header.SetContentLength(0) + } else { + uw.ctx.Response.Header.SetContentLength(-1) // means Transfer-Encoding = chunked + } + } + h := uw.ctx.Response.Header.Header() + n, err := uw.writer.Write(h) + if err != nil { + return 0, err + } + uw.ctx.bytesSent += n + uw.ctx.Response.headersWritten = true + } + return 0, nil +} + +func (uw *UnbufferedWriterHttp1) Close() error { + if uw.writer == nil || uw.ctx == nil { + return ErrClosedUnbufferedWriter + } + + // write headers if not already sent (e.g. if there is no body written) + if !uw.ctx.Response.headersWritten { + // skip body, as we are closing without writing body + uw.ctx.Response.SkipBody = true + _, err := uw.WriteHeaders() + if err != nil { + return fmt.Errorf("error writing headers: %w", err) + } + } + + // finalize chunks + if uw.bodyChunkStarted && uw.ctx.Response.Header.IsHTTP11() && !uw.bodyLastChunkSent { + _, _ = uw.writer.Write([]byte("0\r\n\r\n")) + uw.ctx.bytesSent += 5 + } + _ = uw.writer.Flush() + uw.bodyLastChunkSent = true + releaseWriter(uw.ctx.s, uw.writer) + uw.writer = nil + uw.ctx = nil + return nil +}