From 6b78359dc4a30f760a2db6cde69247c0e1e0bd6b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 27 Mar 2022 17:52:39 +0100 Subject: [PATCH] make the responseWriter hijackable --- http3/client_test.go | 6 +++--- http3/response_writer.go | 9 ++++++++- http3/response_writer_test.go | 2 +- http3/server.go | 2 +- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/http3/client_test.go b/http3/client_test.go index fa289baeade..9304414f79d 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -430,7 +430,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.WriteHeader(status) rw.Flush() return buf.Bytes() @@ -718,7 +718,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) @@ -744,7 +744,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.Write([]byte("not gzipped")) rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) diff --git a/http3/response_writer.go b/http3/response_writer.go index 3b81f0a1f1c..9f232e0f2bb 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -23,6 +23,7 @@ type DataStreamer interface { } type responseWriter struct { + conn quic.Connection stream quic.Stream // needed for DataStream() bufferedStream *bufio.Writer @@ -38,12 +39,14 @@ var ( _ http.ResponseWriter = &responseWriter{} _ http.Flusher = &responseWriter{} _ DataStreamer = &responseWriter{} + _ Hijacker = &responseWriter{} ) -func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter { +func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { return &responseWriter{ header: http.Header{}, stream: stream, + conn: conn, bufferedStream: bufio.NewWriter(stream), logger: logger, } @@ -123,6 +126,10 @@ func (w *responseWriter) StreamID() quic.StreamID { return w.stream.StreamID() } +func (w *responseWriter) StreamCreator() StreamCreator { + return w.conn +} + // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index f1f454ccda5..2da3ef014a0 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -25,7 +25,7 @@ var _ = Describe("Response Writer", func() { strBuf = &bytes.Buffer{} str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - rw = newResponseWriter(str, utils.DefaultLogger) + rw = newResponseWriter(str, nil, utils.DefaultLogger) }) decodeHeader := func(str io.Reader) map[string][]string { diff --git a/http3/server.go b/http3/server.go index 69e5ee4c370..e1d818acc67 100644 --- a/http3/server.go +++ b/http3/server.go @@ -503,7 +503,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q ctx = context.WithValue(ctx, ServerContextKey, s) ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) req = req.WithContext(ctx) - r := newResponseWriter(str, s.logger) + r := newResponseWriter(str, conn, s.logger) defer func() { if !r.usedDataStream() { r.Flush()