Skip to content

Commit

Permalink
middleware: add method to WrapResponseWriter for getting response wri…
Browse files Browse the repository at this point in the history
…ting time
  • Loading branch information
vasayxtx committed Dec 28, 2023
1 parent 58ca6d6 commit 1bfa1de
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
17 changes: 17 additions & 0 deletions middleware/wrap_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net"
"net/http"
"time"
)

// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
Expand Down Expand Up @@ -47,18 +48,25 @@ func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWr
// into various parts of the response process.
type WrapResponseWriter interface {
http.ResponseWriter

// Status returns the HTTP status of the request, or 0 if one has not
// yet been sent.
Status() int

// BytesWritten returns the total number of bytes sent to the client.
BytesWritten() int

// ElapsedWriteTime returns the total time spent writing the response.
ElapsedWriteTime() time.Duration

// Tee causes the response body to be written to the given io.Writer in
// addition to proxying the writes through. Only one io.Writer can be
// tee'd to at once: setting a second one will overwrite the first.
// Writes will be sent to the proxy before being written to this
// io.Writer. It is illegal for the tee'd writer to be modified
// concurrently with writes.
Tee(io.Writer)

// Unwrap returns the original proxied target.
Unwrap() http.ResponseWriter
}
Expand All @@ -70,18 +78,22 @@ type basicWriter struct {
wroteHeader bool
code int
bytes int
elapsedTime time.Duration
tee io.Writer
}

func (b *basicWriter) WriteHeader(code int) {
if !b.wroteHeader {
startTime := time.Now()
b.code = code
b.wroteHeader = true
b.ResponseWriter.WriteHeader(code)
b.elapsedTime += time.Since(startTime)
}
}

func (b *basicWriter) Write(buf []byte) (int, error) {
startTime := time.Now()
b.maybeWriteHeader()
n, err := b.ResponseWriter.Write(buf)
if b.tee != nil {
Expand All @@ -91,6 +103,7 @@ func (b *basicWriter) Write(buf []byte) (int, error) {
err = err2
}
}
b.elapsedTime += time.Since(startTime)
b.bytes += n
return n, err
}
Expand All @@ -109,6 +122,10 @@ func (b *basicWriter) BytesWritten() int {
return b.bytes
}

func (b *basicWriter) ElapsedWriteTime() time.Duration {
return b.elapsedTime
}

func (b *basicWriter) Tee(w io.Writer) {
b.tee = w
}
Expand Down
42 changes: 42 additions & 0 deletions middleware/wrap_writer_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"
"time"
)

func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
Expand All @@ -22,3 +24,43 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
t.Fatal("want Flush to have set wroteHeader=true")
}
}

func TestBasicWriterComputesElapsedWriteTime(t *testing.T) {
const delay = 50 * time.Millisecond
rw := &basicWriter{ResponseWriter: &DelayedResponseWriter{ResponseWriter: httptest.NewRecorder(), Delay: delay}}

if rw.ElapsedWriteTime() != 0 {
t.Fatal("write time should be 0 before any writes")
}

startTime := time.Now()

rw.WriteHeader(http.StatusOK)
totalElapsedTime := time.Since(startTime)
if writeTime := rw.ElapsedWriteTime(); writeTime < delay || writeTime > totalElapsedTime {
t.Fatalf("elapsed write time (%s) is not in the expected range (%s, %s)", writeTime, delay, totalElapsedTime)
}

if _, err := rw.Write([]byte("hello")); err != nil {
t.Fatal(err)
}
totalElapsedTime = time.Since(startTime)
if writeTime := rw.ElapsedWriteTime(); writeTime < delay*2 || writeTime > totalElapsedTime {
t.Fatalf("elapsed write time (%s) is not in the expected range (%s, %s)", writeTime, delay*2, totalElapsedTime)
}
}

type DelayedResponseWriter struct {
http.ResponseWriter
Delay time.Duration
}

func (w *DelayedResponseWriter) WriteHeader(statusCode int) {
time.Sleep(w.Delay)
w.ResponseWriter.WriteHeader(statusCode)
}

func (w *DelayedResponseWriter) Write(b []byte) (int, error) {
time.Sleep(w.Delay)
return w.ResponseWriter.Write(b)
}

0 comments on commit 1bfa1de

Please sign in to comment.