Skip to content

Commit

Permalink
Allow ResponseWriters to unwrap writers when flushing/hijacking
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Feb 20, 2024
1 parent fa70db8 commit a5999fc
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 8 deletions.
12 changes: 10 additions & 2 deletions middleware/body_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
}

func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
err := http.NewResponseController(w.ResponseWriter).Flush()
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return http.NewResponseController(w.ResponseWriter).Hijack()
}

func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
50 changes: 50 additions & 0 deletions middleware/body_dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
}
})
}

func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
}

assert.PanicsWithError(t, "response writer flushing is not supported", func() {
bdrw.Flush()
})
}

func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu,
}

bdrw.Flush()
assert.Equal(t, 1, trwu.unwrapCalled)
}

func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
10 changes: 6 additions & 4 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
}

w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
http.NewResponseController(w.ResponseWriter).Flush()
}

func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return http.NewResponseController(w.ResponseWriter).Hijack()
}

func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
Expand Down
30 changes: 30 additions & 0 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
}
}

func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestGzipResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}

func BenchmarkGzip(b *testing.B) {
e := echo.New()

Expand Down
46 changes: 46 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package middleware

import (
"bufio"
"errors"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/http/httptest"
"regexp"
Expand Down Expand Up @@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
})
}
}

type testResponseWriterNoFlushHijack struct {
}

func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
}

func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterNoFlushHijack) Header() http.Header {
return nil
}

type testResponseWriterUnwrapper struct {
unwrapCalled int
rw http.ResponseWriter
}

func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
}

func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterUnwrapper) Header() http.Header {
return nil
}

func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
w.unwrapCalled++
return w.rw
}

type testResponseWriterUnwrapperHijack struct {
testResponseWriterUnwrapper
}

func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("can hijack")
}
8 changes: 6 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package echo

import (
"bufio"
"errors"
"net"
"net/http"
)
Expand Down Expand Up @@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
// buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() {
r.Writer.(http.Flusher).Flush()
err := http.NewResponseController(r.Writer).Flush()

Check failure on line 88 in response.go

View workflow job for this annotation

GitHub Actions / ubuntu-latest @ Go 1.19

undefined: http.NewResponseController
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack()
return http.NewResponseController(r.Writer).Hijack()

Check failure on line 98 in response.go

View workflow job for this annotation

GitHub Actions / ubuntu-latest @ Go 1.19

undefined: http.NewResponseController
}

// Unwrap returns the original http.ResponseWriter.
Expand Down
25 changes: 25 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
assert.True(t, rec.Flushed)
}

type testResponseWriter struct {
}

func (w *testResponseWriter) WriteHeader(statusCode int) {
}

func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriter) Header() http.Header {
return nil
}

func TestResponse_FlushPanics(t *testing.T) {
e := New()
rw := new(testResponseWriter)
res := &Response{echo: e, Writer: rw}

// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
res.Flush()
})
}

func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
Expand Down

0 comments on commit a5999fc

Please sign in to comment.