Skip to content

Commit

Permalink
make the responseWriter hijackable
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Apr 3, 2022
1 parent a983db0 commit ff6313f
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
6 changes: 3 additions & 3 deletions http3/client_test.go
Expand Up @@ -429,7 +429,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
Expand Down Expand Up @@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
Expand All @@ -743,7 +743,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
Expand Down
9 changes: 8 additions & 1 deletion http3/response_writer.go
Expand Up @@ -23,6 +23,7 @@ type DataStreamer interface {
}

type responseWriter struct {
conn quic.Connection
stream quic.Stream // needed for DataStream()
bufferedStream *bufio.Writer

Expand All @@ -38,12 +39,14 @@ var (
_ http.ResponseWriter = &responseWriter{}
_ http.Flusher = &responseWriter{}
_ DataStreamer = &responseWriter{}
_ Hijacker = &responseWriter{}
)

func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter {
func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
return &responseWriter{
header: http.Header{},
stream: stream,
conn: conn,
bufferedStream: bufio.NewWriter(stream),
logger: logger,
}
Expand Down Expand Up @@ -123,6 +126,10 @@ func (w *responseWriter) StreamID() quic.StreamID {
return w.stream.StreamID()
}

func (w *responseWriter) StreamCreator() StreamCreator {
return w.conn
}

// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
Expand Down
2 changes: 1 addition & 1 deletion http3/response_writer_test.go
Expand Up @@ -25,7 +25,7 @@ var _ = Describe("Response Writer", func() {
strBuf = &bytes.Buffer{}
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
rw = newResponseWriter(str, utils.DefaultLogger)
rw = newResponseWriter(str, nil, utils.DefaultLogger)
})

decodeHeader := func(str io.Reader) map[string][]string {
Expand Down
2 changes: 1 addition & 1 deletion http3/server.go
Expand Up @@ -503,7 +503,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
req = req.WithContext(ctx)
r := newResponseWriter(str, s.logger)
r := newResponseWriter(str, conn, s.logger)
defer func() {
if !r.usedDataStream() {
r.Flush()
Expand Down

0 comments on commit ff6313f

Please sign in to comment.