Skip to content

Commit

Permalink
http3: use the stream context to detect when the send side is closed (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 5, 2024
1 parent c7b58b5 commit bb6f066
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 33 deletions.
1 change: 1 addition & 0 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ var _ = Describe("Client", func() {
return len(b), nil
}) // SETTINGS frame
str = mockquic.NewMockStream(mockCtrl)
str.EXPECT().Context().Return(context.Background()).AnyTimes()
str.EXPECT().StreamID().AnyTimes()
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
Expand Down
2 changes: 2 additions & 0 deletions http3/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ var _ = Describe("Connection", func() {
// ... then open the stream
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().Return(strID).MinTimes(1)
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil)
str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -397,6 +398,7 @@ var _ = Describe("Connection", func() {
// first open the stream...
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().Return(strID).MinTimes(1)
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil)
str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
Expect(err).ToNot(HaveOccurred())
Expand Down
4 changes: 2 additions & 2 deletions http3/datagram.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ start:
d.mx.Unlock()
return data, nil
}
if d.receiveErr != nil {
if receiveErr := d.receiveErr; receiveErr != nil {
d.mx.Unlock()
return nil, d.receiveErr
return nil, receiveErr
}
d.mx.Unlock()

Expand Down
25 changes: 8 additions & 17 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ var _ = Describe("Server", func() {
exampleGetRequest *http.Request
examplePostRequest *http.Request
)
reqContext := context.Background()
reqContext, reqContextCancel := context.WithCancel(context.Background())

decodeHeader := func(str io.Reader) map[string][]string {
fields := make(map[string][]string)
Expand Down Expand Up @@ -140,6 +140,7 @@ var _ = Describe("Server", func() {

qpackDecoder = qpack.NewDecoder(nil)
str = mockquic.NewMockStream(mockCtrl)
str.EXPECT().Context().Return(reqContext).AnyTimes()
str.EXPECT().StreamID().AnyTimes()
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
Expand All @@ -157,7 +158,6 @@ var _ = Describe("Server", func() {
})

setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
Expand All @@ -178,7 +178,6 @@ var _ = Describe("Server", func() {

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expand All @@ -195,7 +194,6 @@ var _ = Describe("Server", func() {

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expand All @@ -216,7 +214,6 @@ var _ = Describe("Server", func() {

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expand All @@ -237,7 +234,6 @@ var _ = Describe("Server", func() {

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expand All @@ -258,7 +254,6 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(headRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expand All @@ -278,7 +273,6 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(headRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
str.EXPECT().Close()
Expand All @@ -297,7 +291,6 @@ var _ = Describe("Server", func() {

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
Expand All @@ -315,7 +308,6 @@ var _ = Describe("Server", func() {

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
Expand Down Expand Up @@ -355,6 +347,7 @@ var _ = Describe("Server", func() {

buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().StreamID().AnyTimes()
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
Expand All @@ -380,6 +373,7 @@ var _ = Describe("Server", func() {

buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
unknownStr.EXPECT().StreamID().AnyTimes()
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
Expand Down Expand Up @@ -407,6 +401,7 @@ var _ = Describe("Server", func() {

buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
unknownStr.EXPECT().StreamID().AnyTimes()
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
Expand All @@ -428,14 +423,15 @@ var _ = Describe("Server", func() {
const strID = protocol.StreamID(1234 * 4)
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) {
defer close(done)
Expect(ft).To(BeZero())
Expect(str.StreamID()).To(Equal(strID))
Expect(err).To(MatchError(testErr))
return true, nil
}
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
unknownStr.EXPECT().StreamID().Return(strID).AnyTimes()
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
Expand Down Expand Up @@ -586,7 +582,6 @@ var _ = Describe("Server", func() {
responseBuf := &bytes.Buffer{}
setRequest(append(requestData, b...))
done := make(chan struct{})
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.EXPECT().Close().Do(func() error { close(done); return nil })
Expand All @@ -610,7 +605,6 @@ var _ = Describe("Server", func() {
b := (&dataFrame{Length: 6}).Append(nil) // add a body
b = append(b, []byte("foobar")...)
setRequest(append(requestData, b...))
str.EXPECT().Context().Return(reqContext)
var buf bytes.Buffer
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()

Expand Down Expand Up @@ -726,7 +720,6 @@ var _ = Describe("Server", func() {
})

setRequest(encodeRequest(examplePostRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
Expand All @@ -747,9 +740,7 @@ var _ = Describe("Server", func() {
})
setRequest(encodeRequest(examplePostRequest))

reqContext, cancel := context.WithCancel(context.Background())
cancel()
str.EXPECT().Context().Return(reqContext)
reqContextCancel()
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
Expand Down
4 changes: 4 additions & 0 deletions http3/state_tracking_stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http3

import (
"context"
"errors"
"sync"

Expand All @@ -26,6 +27,9 @@ type stateTrackingStream struct {
}

func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream {
context.AfterFunc(s.Context(), func() {
onStateChange(streamStateSendClosed, context.Cause(s.Context()))
})
return &stateTrackingStream{
Stream: s,
state: streamStateOpen,
Expand Down
68 changes: 54 additions & 14 deletions http3/state_tracking_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http3

import (
"bytes"
"context"
"errors"
"io"

Expand All @@ -19,22 +20,15 @@ type stateTransition struct {
}

var _ = Describe("State Tracking Stream", func() {
var (
qstr *mockquic.MockStream
str *stateTrackingStream
states []stateTransition
)

BeforeEach(func() {
states = nil
qstr = mockquic.NewMockStream(mockCtrl)
It("recognizes when the receive side is closed", func() {
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().AnyTimes()
str = newStateTrackingStream(qstr, func(state streamState, err error) {
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
var states []stateTransition
str := newStateTrackingStream(qstr, func(state streamState, err error) {
states = append(states, stateTransition{state, err})
})
})

It("recognizes when the receive side is closed", func() {
buf := bytes.NewBuffer([]byte("foobar"))
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
for i := 0; i < 3; i++ {
Expand All @@ -50,6 +44,14 @@ var _ = Describe("State Tracking Stream", func() {
})

It("recognizes read cancellations", func() {
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().AnyTimes()
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
var states []stateTransition
str := newStateTrackingStream(qstr, func(state streamState, err error) {
states = append(states, stateTransition{state, err})
})

buf := bytes.NewBuffer([]byte("foobar"))
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337))
Expand All @@ -62,7 +64,15 @@ var _ = Describe("State Tracking Stream", func() {
Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337}))
})

It("recognizes when the send side is closed", func() {
It("recognizes when the send side is closed, when write errors", func() {
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().AnyTimes()
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
var states []stateTransition
str := newStateTrackingStream(qstr, func(state streamState, err error) {
states = append(states, stateTransition{state, err})
})

testErr := errors.New("test error")
qstr.EXPECT().Write([]byte("foo")).Return(3, nil)
qstr.EXPECT().Write([]byte("bar")).Return(0, testErr)
Expand All @@ -76,7 +86,15 @@ var _ = Describe("State Tracking Stream", func() {
Expect(states[0].err).To(Equal(testErr))
})

It("recognizes write cancellations", func() {
It("recognizes when the send side is closed, when CancelWrite is called", func() {
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().AnyTimes()
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
var states []stateTransition
str := newStateTrackingStream(qstr, func(state streamState, err error) {
states = append(states, stateTransition{state, err})
})

qstr.EXPECT().Write(gomock.Any())
qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337))
_, err := str.Write([]byte("foobar"))
Expand All @@ -87,4 +105,26 @@ var _ = Describe("State Tracking Stream", func() {
Expect(states[0].state).To(Equal(streamStateSendClosed))
Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337}))
})

It("recognizes when the send side is closed, when the stream context is canceled", func() {
qstr := mockquic.NewMockStream(mockCtrl)
qstr.EXPECT().StreamID().AnyTimes()
ctx, cancel := context.WithCancelCause(context.Background())
qstr.EXPECT().Context().Return(ctx).AnyTimes()
var states []stateTransition

done := make(chan struct{})
newStateTrackingStream(qstr, func(state streamState, err error) {
states = append(states, stateTransition{state, err})
close(done)
})

Expect(states).To(BeEmpty())
testErr := errors.New("test error")
cancel(testErr)
Eventually(done).Should(BeClosed())
Expect(states).To(HaveLen(1))
Expect(states[0].state).To(Equal(streamStateSendClosed))
Expect(states[0].err).To(Equal(testErr))
})
})
40 changes: 40 additions & 0 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,46 @@ var _ = Describe("HTTP tests", func() {
// make sure we can't send anymore
Expect(str.SendDatagram([]byte("foo"))).ToNot(Succeed())
})

It("detecting a stream reset from the server", func() {
errChan := make(chan error, 1)
datagramChan := make(chan []byte, 1)
mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
conn := w.(http3.Hijacker).Connection()
Eventually(conn.ReceivedSettings()).Should(BeClosed())
Expect(conn.Settings().EnableDatagrams).To(BeTrue())
w.WriteHeader(http.StatusOK)

str := w.(http3.HTTPStreamer).HTTPStream()
go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions

for {
data, err := str.ReceiveDatagram(context.Background())
if err != nil {
errChan <- err
return
}
str.CancelRead(42)
datagramChan <- data
}
})

str, closeFn := openDatagramStream(fmt.Sprintf("https://localhost:%d/datagrams", port))
defer closeFn()
go str.Read([]byte{0})

Expect(str.SendDatagram([]byte("foo"))).To(Succeed())
Eventually(datagramChan).Should(Receive(Equal([]byte("foo"))))
// signal that we're done sending

var resetErr error
Eventually(errChan).Should(Receive(&resetErr))
Expect(resetErr).To(Equal(&quic.StreamError{ErrorCode: 42, Remote: false}))

// make sure we can't send anymore
Expect(str.SendDatagram([]byte("foo"))).To(Equal(&quic.StreamError{ErrorCode: 42, Remote: true}))
})
})

Context("0-RTT", func() {
Expand Down

0 comments on commit bb6f066

Please sign in to comment.