From 30401e287e43b4c555a8fe4cfd55ed7ff18a8e31 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Jul 2022 16:48:17 +0000 Subject: [PATCH] http3: ignore context after response when using DontCloseRequestStream --- http3/body.go | 4 +++- http3/client.go | 11 +++++++++-- http3/client_test.go | 20 ++++++++++++++++++++ http3/roundtrip.go | 1 + 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/http3/body.go b/http3/body.go index b3d1afd7ba6..d6e704ebcb1 100644 --- a/http3/body.go +++ b/http3/body.go @@ -110,7 +110,9 @@ func (r *hijackableBody) requestDone() { if r.reqDoneClosed || r.reqDone == nil { return } - close(r.reqDone) + if r.reqDone != nil { + close(r.reqDone) + } r.reqDoneClosed = true } diff --git a/http3/client.go b/http3/client.go index e4c51688a45..f7b9ba7c246 100644 --- a/http3/client.go +++ b/http3/client.go @@ -282,7 +282,11 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } }() - rsp, rerr := c.doRequest(req, str, opt, reqDone) + doneChan := reqDone + if opt.DontCloseRequestStream { + doneChan = nil + } + rsp, rerr := c.doRequest(req, str, opt, doneChan) if rerr.err != nil { // if any error occurred close(reqDone) if rerr.streamErr != 0 { // if it was a stream error @@ -296,6 +300,9 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } } + if opt.DontCloseRequestStream { + close(reqDone) + } return rsp, rerr.err } @@ -326,7 +333,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { return nil } -func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { +func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true diff --git a/http3/client_test.go b/http3/client_test.go index 8aeeeff6d0d..c7075ef1847 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -913,6 +913,26 @@ var _ = Describe("Client", func() { cancel() Eventually(done).Should(BeClosed()) }) + + It("doesn't cancel a request if DontCloseRequestStream is set", func() { + rspBuf := bytes.NewBuffer(getResponse(404)) + + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) + Expect(err).ToNot(HaveOccurred()) + cancel() + _, err = io.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + }) }) Context("gzip compression", func() { diff --git a/http3/roundtrip.go b/http3/roundtrip.go index a4d0a312cc7..5cde95a62fd 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -84,6 +84,7 @@ type RoundTripOpt struct { // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool // DontCloseRequestStream controls whether the request stream is closed after sending the request. + // If set, context cancellations have no effect after the response headers are received. DontCloseRequestStream bool }