diff --git a/http3/body.go b/http3/body.go index 23d4cf556ef..b3d1afd7ba6 100644 --- a/http3/body.go +++ b/http3/body.go @@ -2,13 +2,21 @@ package http3 import ( "context" - "fmt" "io" "net" "github.com/lucas-clemente/quic-go" ) +// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: +// * for the server: the http.Request.Body +// * for the client: the http.Response.Body +// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. +// When a stream is taken over, it's the caller's responsibility to close the stream. +type HTTPStreamer interface { + HTTPStream() Stream +} + type StreamCreator interface { OpenStream() (quic.Stream, error) OpenStreamSync(context.Context) (quic.Stream, error) @@ -30,41 +38,59 @@ type Hijacker interface { type body struct { str quic.Stream - // only set for the http.Response - // The channel is closed when the user is done with this response: - // either when Read() errors, or when Close() is called. - reqDone chan<- struct{} - reqDoneClosed bool + wasHijacked bool // set when HTTPStream is called +} + +var ( + _ io.ReadCloser = &body{} + _ HTTPStreamer = &body{} +) - onFrameError func() +func newRequestBody(str Stream) *body { + return &body{str: str} +} + +func (r *body) HTTPStream() Stream { + r.wasHijacked = true + return r.str +} + +func (r *body) wasStreamHijacked() bool { + return r.wasHijacked +} - bytesRemainingInFrame uint64 +func (r *body) Read(b []byte) (int, error) { + return r.str.Read(b) } -var _ io.ReadCloser = &body{} +func (r *body) Close() error { + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + return nil +} type hijackableBody struct { body conn quic.Connection // only needed to implement Hijacker -} - -var _ Hijacker = &hijackableBody{} -func newRequestBody(str quic.Stream, onFrameError func()) *body { - return &body{ - str: str, - onFrameError: onFrameError, - } + // only set for the http.Response + // The channel is closed when the user is done with this response: + // either when Read() errors, or when Close() is called. + reqDone chan<- struct{} + reqDoneClosed bool } -func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody { +var ( + _ Hijacker = &hijackableBody{} + _ HTTPStreamer = &hijackableBody{} +) + +func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { return &hijackableBody{ body: body{ - str: str, - onFrameError: onFrameError, - reqDone: done, + str: str, }, - conn: conn, + reqDone: done, + conn: conn, } } @@ -72,50 +98,15 @@ func (r *hijackableBody) StreamCreator() StreamCreator { return r.conn } -func (r *body) Read(b []byte) (int, error) { - n, err := r.readImpl(b) +func (r *hijackableBody) Read(b []byte) (int, error) { + n, err := r.str.Read(b) if err != nil { r.requestDone() } return n, err } -func (r *body) readImpl(b []byte) (int, error) { - if r.bytesRemainingInFrame == 0 { - parseLoop: - for { - frame, err := parseNextFrame(r.str, nil) - if err != nil { - return 0, err - } - switch f := frame.(type) { - case *headersFrame: - // skip HEADERS frames - continue - case *dataFrame: - r.bytesRemainingInFrame = f.Length - break parseLoop - default: - r.onFrameError() - // parseNextFrame skips over unknown frame types - // Therefore, this condition is only entered when we parsed another known frame type. - return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) - } - } - } - - var n int - var err error - if r.bytesRemainingInFrame < uint64(len(b)) { - n, err = r.str.Read(b[:r.bytesRemainingInFrame]) - } else { - n, err = r.str.Read(b) - } - r.bytesRemainingInFrame -= uint64(n) - return n, err -} - -func (r *body) requestDone() { +func (r *hijackableBody) requestDone() { if r.reqDoneClosed || r.reqDone == nil { return } @@ -127,9 +118,13 @@ func (r *body) StreamID() quic.StreamID { return r.str.StreamID() } -func (r *body) Close() error { +func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) return nil } + +func (r *hijackableBody) HTTPStream() Stream { + return r.str +} diff --git a/http3/body_test.go b/http3/body_test.go index f50004dc325..4920357dda3 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -1,189 +1,54 @@ package http3 import ( - "bytes" - "fmt" - "io" + "errors" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type bodyType uint8 - -const ( - bodyTypeRequest bodyType = iota - bodyTypeResponse -) - -func (t bodyType) String() string { - if t == bodyTypeRequest { - return "request" - } - return "response" -} - -var _ = Describe("Body", func() { - var ( - rb io.ReadCloser - str *mockquic.MockStream - buf *bytes.Buffer - reqDone chan struct{} - errorCbCalled bool - ) +var _ = Describe("Response Body", func() { + var reqDone chan struct{} - errorCb := func() { errorCbCalled = true } + BeforeEach(func() { reqDone = make(chan struct{}) }) - getDataFrame := func(data []byte) []byte { - b := &bytes.Buffer{} - (&dataFrame{Length: uint64(len(data))}).Write(b) - b.Write(data) - return b.Bytes() - } - - BeforeEach(func() { - buf = &bytes.Buffer{} - errorCbCalled = false + It("closes the reqDone channel when Read errors", func() { + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")) + rb := newResponseBody(str, nil, reqDone) + _, err := rb.Read([]byte{0}) + Expect(err).To(MatchError("test error")) + Expect(reqDone).To(BeClosed()) }) - for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} { - bodyType := bt - - Context(fmt.Sprintf("using a %s body", bodyType), func() { - BeforeEach(func() { - str = mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { - return buf.Write(b) - }).AnyTimes() - str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { - return buf.Read(b) - }).AnyTimes() - - switch bodyType { - case bodyTypeRequest: - rb = newRequestBody(str, errorCb) - case bodyTypeResponse: - reqDone = make(chan struct{}) - rb = newResponseBody(str, nil, reqDone, errorCb) - } - }) - - It("reads DATA frames in a single run", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 6) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("reads DATA frames in multiple runs", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 3) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b).To(Equal([]byte("foo"))) - n, err = rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b).To(Equal([]byte("bar"))) - }) - - It("reads DATA frames into too large buffers", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 10) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b[:n]).To(Equal([]byte("foobar"))) - }) - - It("reads DATA frames into too large buffers, in multiple runs", func() { - buf.Write(getDataFrame([]byte("foobar"))) - b := make([]byte, 4) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(4)) - Expect(b).To(Equal([]byte("foob"))) - n, err = rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(2)) - Expect(b[:n]).To(Equal([]byte("ar"))) - }) - - It("reads multiple DATA frames", func() { - buf.Write(getDataFrame([]byte("foo"))) - buf.Write(getDataFrame([]byte("bar"))) - b := make([]byte, 6) - n, err := rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b[:n]).To(Equal([]byte("foo"))) - n, err = rb.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - Expect(b[:n]).To(Equal([]byte("bar"))) - }) - - It("skips HEADERS frames", func() { - buf.Write(getDataFrame([]byte("foo"))) - (&headersFrame{Length: 10}).Write(buf) - buf.Write(make([]byte, 10)) - buf.Write(getDataFrame([]byte("bar"))) - b := make([]byte, 6) - n, err := io.ReadFull(rb, b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) - - It("errors when it can't parse the frame", func() { - buf.Write([]byte("invalid")) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) - - It("errors on unexpected frames, and calls the error callback", func() { - (&settingsFrame{}).Write(buf) - _, err := rb.Read([]byte{0}) - Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) - Expect(errorCbCalled).To(BeTrue()) - }) - - if bodyType == bodyTypeResponse { - It("closes the reqDone channel when Read errors", func() { - buf.Write([]byte("invalid")) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(reqDone).To(BeClosed()) - }) - - It("allows multiple calls to Read, when Read errors", func() { - buf.Write([]byte("invalid")) - _, err := rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - Expect(reqDone).To(BeClosed()) - _, err = rb.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) + It("allows multiple calls to Read, when Read errors", func() { + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test error")).Times(2) + rb := newResponseBody(str, nil, reqDone) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + _, err = rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) - It("closes responses", func() { - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) - Expect(rb.Close()).To(Succeed()) - }) + It("closes responses", func() { + str := mockquic.NewMockStream(mockCtrl) + rb := newResponseBody(str, nil, reqDone) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + Expect(rb.Close()).To(Succeed()) + }) - It("allows multiple calls to Close", func() { - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) - Expect(rb.Close()).To(Succeed()) - Expect(reqDone).To(BeClosed()) - Expect(rb.Close()).To(Succeed()) - }) - } - }) - } + It("allows multiple calls to Close", func() { + str := mockquic.NewMockStream(mockCtrl) + rb := newResponseBody(str, nil, reqDone) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) + Expect(rb.Close()).To(Succeed()) + Expect(reqDone).To(BeClosed()) + Expect(rb.Close()).To(Succeed()) + }) }) diff --git a/http3/client.go b/http3/client.go index 325fd4d4456..c56a8a35707 100644 --- a/http3/client.go +++ b/http3/client.go @@ -298,15 +298,59 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, rerr.err } +func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { + defer body.Close() + b := make([]byte, bodyCopyBufferSize) + for { + n, rerr := body.Read(b) + if n == 0 { + if rerr == nil { + continue + } + if rerr == io.EOF { + break + } + } + if _, err := str.Write(b[:n]); err != nil { + return err + } + if rerr != nil { + if rerr == io.EOF { + break + } + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + return rerr + } + } + return nil +} + func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true } - if err := c.requestWriter.WriteRequest(str, req, opt.DontCloseRequestStream, requestGzip); err != nil { + if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil { return nil, newStreamError(errorInternalError, err) } + if req.Body == nil && !opt.DontCloseRequestStream { + str.Close() + } + + hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + if req.Body != nil { + // send the request body asynchronously + go func() { + if err := c.sendRequestBody(hstr, req.Body); err != nil { + c.logger.Errorf("Error writing request: %s", err) + } + if !opt.DontCloseRequestStream { + hstr.Close() + } + }() + } + frame, err := parseNextFrame(str, nil) if err != nil { return nil, newStreamError(errorFrameError, err) @@ -348,9 +392,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(str, c.conn, reqDone, func() { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") - }) + respBody := newResponseBody(hstr, c.conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. _, hasTransferEncoding := res.Header["Transfer-Encoding"] diff --git a/http3/client_test.go b/http3/client_test.go index 9be1c6849d0..f512fd41857 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -797,6 +797,7 @@ var _ = Describe("Client", func() { <-done return 0, errors.New("test done") }) + str.EXPECT().Close() _, err := client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) }) diff --git a/http3/http_stream.go b/http3/http_stream.go new file mode 100644 index 00000000000..4c69068cdbb --- /dev/null +++ b/http3/http_stream.go @@ -0,0 +1,71 @@ +package http3 + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go" +) + +// A Stream is a HTTP/3 stream. +// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. +type Stream quic.Stream + +// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly +// from the QUIC stream, it writes to and reads from the HTTP stream. +type stream struct { + quic.Stream + + onFrameError func() + bytesRemainingInFrame uint64 +} + +var _ Stream = &stream{} + +func newStream(str quic.Stream, onFrameError func()) *stream { + return &stream{Stream: str, onFrameError: onFrameError} +} + +func (s *stream) Read(b []byte) (int, error) { + if s.bytesRemainingInFrame == 0 { + parseLoop: + for { + frame, err := parseNextFrame(s.Stream, nil) + if err != nil { + return 0, err + } + switch f := frame.(type) { + case *headersFrame: + // skip HEADERS frames + continue + case *dataFrame: + s.bytesRemainingInFrame = f.Length + break parseLoop + default: + s.onFrameError() + // parseNextFrame skips over unknown frame types + // Therefore, this condition is only entered when we parsed another known frame type. + return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) + } + } + } + + var n int + var err error + if s.bytesRemainingInFrame < uint64(len(b)) { + n, err = s.Stream.Read(b[:s.bytesRemainingInFrame]) + } else { + n, err = s.Stream.Read(b) + } + s.bytesRemainingInFrame -= uint64(n) + return n, err +} + +func (s *stream) Write(b []byte) (int, error) { + buf := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(b))}).Write(buf) + if _, err := s.Stream.Write(buf.Bytes()); err != nil { + return 0, err + } + return s.Stream.Write(b) +} diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go new file mode 100644 index 00000000000..ad9833b97c9 --- /dev/null +++ b/http3/http_stream_test.go @@ -0,0 +1,150 @@ +package http3 + +import ( + "bytes" + "io" + + mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream", func() { + Context("reading", func() { + var ( + str Stream + qstr *mockquic.MockStream + buf *bytes.Buffer + errorCbCalled bool + ) + + errorCb := func() { errorCbCalled = true } + getDataFrame := func(data []byte) []byte { + b := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(data))}).Write(b) + b.Write(data) + return b.Bytes() + } + + BeforeEach(func() { + buf = &bytes.Buffer{} + errorCbCalled = false + qstr = mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + str = newStream(qstr, errorCb) + }) + + It("reads DATA frames in a single run", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("reads DATA frames in multiple runs", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 3) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b).To(Equal([]byte("foo"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b).To(Equal([]byte("bar"))) + }) + + It("reads DATA frames into too large buffers", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 10) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b[:n]).To(Equal([]byte("foobar"))) + }) + + It("reads DATA frames into too large buffers, in multiple runs", func() { + buf.Write(getDataFrame([]byte("foobar"))) + b := make([]byte, 4) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(4)) + Expect(b).To(Equal([]byte("foob"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(2)) + Expect(b[:n]).To(Equal([]byte("ar"))) + }) + + It("reads multiple DATA frames", func() { + buf.Write(getDataFrame([]byte("foo"))) + buf.Write(getDataFrame([]byte("bar"))) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b[:n]).To(Equal([]byte("foo"))) + n, err = str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(b[:n]).To(Equal([]byte("bar"))) + }) + + It("skips HEADERS frames", func() { + buf.Write(getDataFrame([]byte("foo"))) + (&headersFrame{Length: 10}).Write(buf) + buf.Write(make([]byte, 10)) + buf.Write(getDataFrame([]byte("bar"))) + b := make([]byte, 6) + n, err := io.ReadFull(str, b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) + + It("errors when it can't parse the frame", func() { + buf.Write([]byte("invalid")) + _, err := str.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + + It("errors on unexpected frames, and calls the error callback", func() { + (&settingsFrame{}).Write(buf) + _, err := str.Read([]byte{0}) + Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) + Expect(errorCbCalled).To(BeTrue()) + }) + }) + + Context("writing", func() { + It("writes data frames", func() { + buf := &bytes.Buffer{} + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() + str := newStream(qstr, nil) + str.Write([]byte("foo")) + str.Write([]byte("foobar")) + + f, err := parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&dataFrame{Length: 3})) + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foo"))) + + f, err = parseNextFrame(buf, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(Equal(&dataFrame{Length: 6})) + b = make([]byte, 6) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foobar"))) + }) + }) +}) diff --git a/http3/request_writer.go b/http3/request_writer.go index bd36c7ddc94..0a9c67ac5d7 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -38,60 +38,14 @@ func newRequestWriter(logger utils.Logger) *requestWriter { } } -func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, dontCloseStr, gzip bool) error { +func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error { + // TODO: figure out how to add support for trailers buf := &bytes.Buffer{} if err := w.writeHeaders(buf, req, gzip); err != nil { return err } - if _, err := str.Write(buf.Bytes()); err != nil { - return err - } - // TODO: add support for trailers - if req.Body == nil { - if !dontCloseStr { - str.Close() - } - return nil - } - - // send the request body asynchronously - go func() { - defer req.Body.Close() - b := make([]byte, bodyCopyBufferSize) - for { - n, rerr := req.Body.Read(b) - if n == 0 { - if rerr == nil { - continue - } else if rerr == io.EOF { - break - } - } - buf := &bytes.Buffer{} - (&dataFrame{Length: uint64(n)}).Write(buf) - if _, err := str.Write(buf.Bytes()); err != nil { - w.logger.Errorf("Error writing request: %s", err) - return - } - if _, err := str.Write(b[:n]); err != nil { - w.logger.Errorf("Error writing request: %s", err) - return - } - if rerr != nil { - if rerr == io.EOF { - break - } - str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) - w.logger.Errorf("Error writing request: %s", rerr) - return - } - } - if !dontCloseStr { - str.Close() - } - }() - - return nil + _, err := str.Write(buf.Bytes()) + return err } func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error { diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index e2c80cdc1d9..1e5a161432a 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "net/http" - "strconv" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" "github.com/lucas-clemente/quic-go/internal/utils" @@ -16,12 +15,6 @@ import ( . "github.com/onsi/gomega" ) -type foobarReader struct{} - -func (r *foobarReader) Read(b []byte) (int, error) { - return copy(b, []byte("foobar")), io.EOF -} - var _ = Describe("Request Writer", func() { var ( rw *requestWriter @@ -51,16 +44,13 @@ var _ = Describe("Request Writer", func() { rw = newRequestWriter(utils.DefaultLogger) strBuf = &bytes.Buffer{} str = mockquic.NewMockStream(mockCtrl) - str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { - return strBuf.Write(p) - }).AnyTimes() + str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() }) It("writes a GET request", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) @@ -69,55 +59,7 @@ var _ = Describe("Request Writer", func() { Expect(headerFields).ToNot(HaveKey("accept-encoding")) }) - It("writes a GET request without closing the stream", func() { - req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io", nil) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, true, false)).To(Succeed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) - }) - - It("writes a POST request", func() { - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - postData := bytes.NewReader([]byte("foobar")) - req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) - - Eventually(closed).Should(BeClosed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) - Expect(headerFields).To(HaveKey("content-length")) - contentLength, err := strconv.Atoi(headerFields["content-length"]) - Expect(err).ToNot(HaveOccurred()) - Expect(contentLength).To(BeNumerically(">", 0)) - - frame, err := parseNextFrame(strBuf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) - }) - - It("writes a POST request, if the Body returns an EOF immediately", func() { - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) - req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{}) - Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) - - Eventually(closed).Should(BeClosed()) - headerFields := decode(strBuf) - Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) - - frame, err := parseNextFrame(strBuf, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) - Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) - }) - It("sends cookies", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) cookie1 := &http.Cookie{ @@ -130,25 +72,23 @@ var _ = Describe("Request Writer", func() { } req.AddCookie(cookie1) req.AddCookie(cookie2) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) }) It("adds the header for gzip support", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, true)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, true)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) }) It("writes a CONNECT request", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) @@ -158,11 +98,10 @@ var _ = Describe("Request Writer", func() { }) It("writes an Extended CONNECT request", func() { - str.EXPECT().Close() req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) Expect(err).ToNot(HaveOccurred()) req.Proto = "webtransport" - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) diff --git a/http3/response_writer.go b/http3/response_writer.go index c0de2d4ec2a..70a7cd3f484 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -13,9 +13,8 @@ import ( ) type responseWriter struct { - conn quic.Connection - stream quic.Stream // needed for DataStream() - bufferedStream *bufio.Writer + conn quic.Connection + bufferedStr *bufio.Writer header http.Header status int // status code passed to WriteHeader @@ -30,13 +29,12 @@ var ( _ Hijacker = &responseWriter{} ) -func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { +func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { return &responseWriter{ - header: http.Header{}, - stream: stream, - conn: conn, - bufferedStream: bufio.NewWriter(stream), - logger: logger, + header: http.Header{}, + conn: conn, + bufferedStr: bufio.NewWriter(str), + logger: logger, } } @@ -67,10 +65,10 @@ func (w *responseWriter) WriteHeader(status int) { buf := &bytes.Buffer{} (&headersFrame{Length: uint64(headers.Len())}).Write(buf) w.logger.Infof("Responding with %d", status) - if _, err := w.bufferedStream.Write(buf.Bytes()); err != nil { + if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { w.logger.Errorf("could not write headers frame: %s", err.Error()) } - if _, err := w.bufferedStream.Write(headers.Bytes()); err != nil { + if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { w.logger.Errorf("could not write header frame payload: %s", err.Error()) } if !w.headerWritten { @@ -88,22 +86,18 @@ func (w *responseWriter) Write(p []byte) (int, error) { df := &dataFrame{Length: uint64(len(p))} buf := &bytes.Buffer{} df.Write(buf) - if _, err := w.bufferedStream.Write(buf.Bytes()); err != nil { + if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { return 0, err } - return w.bufferedStream.Write(p) + return w.bufferedStr.Write(p) } func (w *responseWriter) Flush() { - if err := w.bufferedStream.Flush(); err != nil { + if err := w.bufferedStr.Flush(); err != nil { w.logger.Errorf("could not flush to stream: %s", err.Error()) } } -func (w *responseWriter) StreamID() quic.StreamID { - return w.stream.StreamID() -} - func (w *responseWriter) StreamCreator() StreamCreator { return w.conn } diff --git a/http3/server.go b/http3/server.go index 45ca3f4c81b..cc904433220 100644 --- a/http3/server.go +++ b/http3/server.go @@ -549,7 +549,8 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q } req.RemoteAddr = conn.RemoteAddr().String() - req.Body = newRequestBody(str, onFrameError) + body := newRequestBody(newStream(str, onFrameError)) + req.Body = body if s.logger.Debug() { s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) @@ -583,6 +584,10 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q handler.ServeHTTP(r, req) }() + if body.wasStreamHijacked() { + return requestError{err: errHijacked} + } + if panicked { r.WriteHeader(500) } else { diff --git a/http3/server_test.go b/http3/server_test.go index 064380c3d88..6b4cae02e63 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -135,11 +135,8 @@ var _ = Describe("Server", func() { buf := &bytes.Buffer{} str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) rw := newRequestWriter(utils.DefaultLogger) - Expect(rw.WriteRequest(str, req, false, false)).To(Succeed()) - Eventually(closed).Should(BeClosed()) + Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) return buf.Bytes() } @@ -162,7 +159,6 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) - conn = mockquic.NewMockEarlyConnection(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() @@ -483,7 +479,7 @@ var _ = Describe("Server", func() { }) } - It("reset streams Other than the control stream and the QPACK streams", func() { + It("reset streams other than the control stream and the QPACK streams", func() { buf := &bytes.Buffer{} quicvarint.Write(buf, 1337) str := mockquic.NewMockStream(mockCtrl) @@ -626,9 +622,9 @@ var _ = Describe("Server", func() { AfterEach(func() { testDone <- struct{}{} }) It("cancels reading when client sends a body in GET request", func() { - handlerCalled := make(chan struct{}) + var handlerCalled bool s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(handlerCalled) + handlerCalled = true }) requestData := encodeRequest(exampleGetRequest) @@ -647,6 +643,27 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(handlerCalled).To(BeTrue()) + }) + + It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { + handlerCalled := make(chan struct{}) + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(handlerCalled) + r.Body.(HTTPStreamer).HTTPStream() + str.Write([]byte("foobar")) + }) + + requestData := encodeRequest(exampleGetRequest) + buf := &bytes.Buffer{} + (&dataFrame{Length: 6}).Write(buf) // add a body + buf.Write([]byte("foobar")) + setRequest(append(requestData, buf.Bytes()...)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write([]byte("foobar")).Return(6, nil) + + s.handleConn(conn) + Eventually(handlerCalled).Should(BeClosed()) }) It("errors when the client sends a too large header frame", func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 0a4e440d3c7..728511e10b8 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -311,6 +311,46 @@ var _ = Describe("HTTP tests", func() { Expect(req.Body.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) + + It("allows taking over the stream", func() { + mux.HandleFunc("/httpstreamer", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.WriteHeader(200) + w.(http.Flusher).Flush() + + str := r.Body.(http3.HTTPStreamer).HTTPStream() + str.Write([]byte("foobar")) + + // Do this in a Go routine, so that the handler returns early. + // This way, we can also check that the HTTP/3 doesn't close the stream. + go func() { + defer GinkgoRecover() + _, err := io.Copy(str, str) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + }) + + req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/httpstreamer", nil) + Expect(err).ToNot(HaveOccurred()) + rsp, err := client.Transport.(*http3.RoundTripper).RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true}) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.StatusCode).To(Equal(200)) + + str := rsp.Body.(http3.HTTPStreamer).HTTPStream() + b := make([]byte, 6) + _, err = io.ReadFull(str, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foobar"))) + + data := GeneratePRData(8 * 1024) + _, err = str.Write(data) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + repl, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(repl).To(Equal(data)) + }) }) } })