Skip to content

Commit

Permalink
module/apmhttp: implement io.ReaderFrom in wrapper
Browse files Browse the repository at this point in the history
If the http.ResponseWriter implements io.ReaderFrom,
make sure the wrapped ResponseWriter also does and
passes through.

Closes #826
  • Loading branch information
axw committed Oct 6, 2020
1 parent b4d025c commit 3e77ebf
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 10 deletions.
61 changes: 51 additions & 10 deletions module/apmhttp/handler.go
Expand Up @@ -19,6 +19,7 @@ package apmhttp

import (
"context"
"io"
"net/http"

"go.elastic.co/apm"
Expand Down Expand Up @@ -153,39 +154,59 @@ func SetContext(ctx *apm.Context, req *http.Request, resp *Response, body *apm.B
// ResponseWriter's Write or WriteHeader methods are called, then the
// response's StatusCode field will be zero.
//
// The returned http.ResponseWriter implements http.Pusher and http.Hijacker
// if and only if the provided http.ResponseWriter does.
// The returned http.ResponseWriter implements http.Pusher, http.Hijacker,
// and io.ReaderFrom if and only if the provided http.ResponseWriter does.
func WrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *Response) {
rw := responseWriter{
ResponseWriter: w,
resp: Response{
Headers: w.Header(),
},
}

h, _ := w.(http.Hijacker)
p, _ := w.(http.Pusher)
rf, _ := w.(io.ReaderFrom)

switch {
case h != nil && p != nil:
rwhp := &responseWriterHijackerPusher{
rwhp := responseWriterHijackerPusher{
responseWriter: rw,
Hijacker: h,
Pusher: p,
}
return rwhp, &rwhp.resp
if rf != nil {
rwhprf := responseWriterHijackerPusherReaderFrom{rwhp, rf}
return &rwhprf, &rwhprf.resp
}
return &rwhp, &rwhp.resp
case h != nil:
rwh := &responseWriterHijacker{
rwh := responseWriterHijacker{
responseWriter: rw,
Hijacker: h,
}
return rwh, &rwh.resp
if rf != nil {
rwhrf := responseWriterHijackerReaderFrom{rwh, rf}
return &rwhrf, &rwhrf.resp
}
return &rwh, &rwh.resp
case p != nil:
rwp := &responseWriterPusher{
rwp := responseWriterPusher{
responseWriter: rw,
Pusher: p,
}
return rwp, &rwp.resp
if rf != nil {
rwprf := responseWriterPusherReaderFrom{rwp, rf}
return &rwprf, &rwprf.resp
}
return &rwp, &rwp.resp
default:
if rf != nil {
rwrf := responseWriterReaderFrom{rw, rf}
return &rwrf, &rwrf.resp
}
return &rw, &rw.resp
}
return &rw, &rw.resp
}

// Response records details of the HTTP response.
Expand Down Expand Up @@ -229,30 +250,50 @@ func (w *responseWriter) CloseNotify() <-chan bool {
return nil
}

// Flush calls w.flush() if w.flush is non-nil, otherwise
// Flush calls w.ResponseWriter's Flush method if implemented, otherwise
// it does nothing.
func (w *responseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

type responseWriterReaderFrom struct {
responseWriter
io.ReaderFrom
}

type responseWriterHijacker struct {
responseWriter
http.Hijacker
}

type responseWriterHijackerReaderFrom struct {
responseWriterHijacker
io.ReaderFrom
}

type responseWriterPusher struct {
responseWriter
http.Pusher
}

type responseWriterPusherReaderFrom struct {
responseWriterPusher
io.ReaderFrom
}

type responseWriterHijackerPusher struct {
responseWriter
http.Hijacker
http.Pusher
}

type responseWriterHijackerPusherReaderFrom struct {
responseWriterHijackerPusher
io.ReaderFrom
}

// ServerOption sets options for tracing server requests.
type ServerOption func(*handler)

Expand Down
21 changes: 21 additions & 0 deletions module/apmhttp/handler_test.go
Expand Up @@ -526,6 +526,27 @@ func TestHandlerTracestateHeader(t *testing.T) {
assert.Equal(t, "", w.Body.String())
}

func TestHandlerReaderFrom(t *testing.T) {
recorder := apmtest.NewRecordingTracer()
defer recorder.Close()

mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Implements(t, new(io.ReaderFrom), w)
rf := w.(io.ReaderFrom)
rf.ReadFrom(strings.NewReader("hello"))
}))

srv := httptest.NewServer(apmhttp.Wrap(mux, apmhttp.WithTracer(recorder.Tracer)))
defer srv.Close()

resp, err := http.Get(srv.URL)
require.NoError(t, err)
content, _ := ioutil.ReadAll(resp.Body)
assert.NoError(t, resp.Body.Close())
assert.Equal(t, "hello", string(content))
}

func panicHandler(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusTeapot)
panic("foo")
Expand Down

0 comments on commit 3e77ebf

Please sign in to comment.