Skip to content

Commit

Permalink
http3: fix double close of chan when using DontCloseRequestStream
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 23, 2022
1 parent 07412be commit bf29e68
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 60 deletions.
1 change: 1 addition & 0 deletions http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
}
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return nil, rerr.err
}
if opt.DontCloseRequestStream {
close(reqDone)
Expand Down
124 changes: 64 additions & 60 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,22 @@ var _ = Describe("Client", func() {
Expect(rsp.StatusCode).To(Equal(418))
})

It("doesn't close the request stream, with DontCloseRequestStream set", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
Expect(rsp.StatusCode).To(Equal(418))
})

Context("requests containing a Body", func() {
var strBuf *bytes.Buffer

Expand Down Expand Up @@ -849,47 +865,55 @@ var _ = Describe("Client", func() {
})

Context("request cancellations", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())

errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, RoundTripOpt{})
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})

It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)

str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)

done := make(chan struct{})
canceled := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
)
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel()
<-canceled
return 0, errors.New("test done")
for _, dontClose := range []bool{false, true} {
dontClose := dontClose

Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}

It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())

errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, roundTripOpt)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})

It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)

str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)

done := make(chan struct{})
canceled := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
)
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel()
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
}

It("cancels a request after the response arrived", func() {
rspBuf := bytes.NewBuffer(getResponse(404))
Expand All @@ -912,26 +936,6 @@ var _ = Describe("Client", func() {
cancel()
Eventually(done).Should(BeClosed())
})

It("doesn't cancel a request if DontCloseRequestStream is set", func() {
rspBuf := bytes.NewBuffer(getResponse(404))

ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)

str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
cancel()
_, err = io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
})
})

Context("gzip compression", func() {
Expand Down

0 comments on commit bf29e68

Please sign in to comment.