From 3e77ebff1b89c858c7e566dcf7a6eff2fb75b3cb Mon Sep 17 00:00:00 2001 From: Andrew Wilkins Date: Tue, 6 Oct 2020 12:22:09 +0800 Subject: [PATCH] module/apmhttp: implement io.ReaderFrom in wrapper If the http.ResponseWriter implements io.ReaderFrom, make sure the wrapped ResponseWriter also does and passes through. Closes #826 --- module/apmhttp/handler.go | 61 ++++++++++++++++++++++++++++------ module/apmhttp/handler_test.go | 21 ++++++++++++ 2 files changed, 72 insertions(+), 10 deletions(-) diff --git a/module/apmhttp/handler.go b/module/apmhttp/handler.go index d786ee46f..57536f6f8 100644 --- a/module/apmhttp/handler.go +++ b/module/apmhttp/handler.go @@ -19,6 +19,7 @@ package apmhttp import ( "context" + "io" "net/http" "go.elastic.co/apm" @@ -153,8 +154,8 @@ func SetContext(ctx *apm.Context, req *http.Request, resp *Response, body *apm.B // ResponseWriter's Write or WriteHeader methods are called, then the // response's StatusCode field will be zero. // -// The returned http.ResponseWriter implements http.Pusher and http.Hijacker -// if and only if the provided http.ResponseWriter does. +// The returned http.ResponseWriter implements http.Pusher, http.Hijacker, +// and io.ReaderFrom if and only if the provided http.ResponseWriter does. func WrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *Response) { rw := responseWriter{ ResponseWriter: w, @@ -162,30 +163,50 @@ func WrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *Response) Headers: w.Header(), }, } + h, _ := w.(http.Hijacker) p, _ := w.(http.Pusher) + rf, _ := w.(io.ReaderFrom) + switch { case h != nil && p != nil: - rwhp := &responseWriterHijackerPusher{ + rwhp := responseWriterHijackerPusher{ responseWriter: rw, Hijacker: h, Pusher: p, } - return rwhp, &rwhp.resp + if rf != nil { + rwhprf := responseWriterHijackerPusherReaderFrom{rwhp, rf} + return &rwhprf, &rwhprf.resp + } + return &rwhp, &rwhp.resp case h != nil: - rwh := &responseWriterHijacker{ + rwh := responseWriterHijacker{ responseWriter: rw, Hijacker: h, } - return rwh, &rwh.resp + if rf != nil { + rwhrf := responseWriterHijackerReaderFrom{rwh, rf} + return &rwhrf, &rwhrf.resp + } + return &rwh, &rwh.resp case p != nil: - rwp := &responseWriterPusher{ + rwp := responseWriterPusher{ responseWriter: rw, Pusher: p, } - return rwp, &rwp.resp + if rf != nil { + rwprf := responseWriterPusherReaderFrom{rwp, rf} + return &rwprf, &rwprf.resp + } + return &rwp, &rwp.resp + default: + if rf != nil { + rwrf := responseWriterReaderFrom{rw, rf} + return &rwrf, &rwrf.resp + } + return &rw, &rw.resp } - return &rw, &rw.resp } // Response records details of the HTTP response. @@ -229,7 +250,7 @@ func (w *responseWriter) CloseNotify() <-chan bool { return nil } -// Flush calls w.flush() if w.flush is non-nil, otherwise +// Flush calls w.ResponseWriter's Flush method if implemented, otherwise // it does nothing. func (w *responseWriter) Flush() { if flusher, ok := w.ResponseWriter.(http.Flusher); ok { @@ -237,22 +258,42 @@ func (w *responseWriter) Flush() { } } +type responseWriterReaderFrom struct { + responseWriter + io.ReaderFrom +} + type responseWriterHijacker struct { responseWriter http.Hijacker } +type responseWriterHijackerReaderFrom struct { + responseWriterHijacker + io.ReaderFrom +} + type responseWriterPusher struct { responseWriter http.Pusher } +type responseWriterPusherReaderFrom struct { + responseWriterPusher + io.ReaderFrom +} + type responseWriterHijackerPusher struct { responseWriter http.Hijacker http.Pusher } +type responseWriterHijackerPusherReaderFrom struct { + responseWriterHijackerPusher + io.ReaderFrom +} + // ServerOption sets options for tracing server requests. type ServerOption func(*handler) diff --git a/module/apmhttp/handler_test.go b/module/apmhttp/handler_test.go index 5804b8e24..0d02f02cb 100644 --- a/module/apmhttp/handler_test.go +++ b/module/apmhttp/handler_test.go @@ -526,6 +526,27 @@ func TestHandlerTracestateHeader(t *testing.T) { assert.Equal(t, "", w.Body.String()) } +func TestHandlerReaderFrom(t *testing.T) { + recorder := apmtest.NewRecordingTracer() + defer recorder.Close() + + mux := http.NewServeMux() + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.Implements(t, new(io.ReaderFrom), w) + rf := w.(io.ReaderFrom) + rf.ReadFrom(strings.NewReader("hello")) + })) + + srv := httptest.NewServer(apmhttp.Wrap(mux, apmhttp.WithTracer(recorder.Tracer))) + defer srv.Close() + + resp, err := http.Get(srv.URL) + require.NoError(t, err) + content, _ := ioutil.ReadAll(resp.Body) + assert.NoError(t, resp.Body.Close()) + assert.Equal(t, "hello", string(content)) +} + func panicHandler(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusTeapot) panic("foo")