diff --git a/receive_stream.go b/receive_stream.go index f9a1e066ff5..ae6a449b575 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -47,6 +47,7 @@ type receiveStream struct { resetRemotely bool // set when HandleResetStreamFrame() is called readChan chan struct{} + readOnce chan struct{} // cap: 1, to protect against concurrent use of Read deadline time.Time flowController flowcontrol.StreamFlowController @@ -70,6 +71,7 @@ func newReceiveStream( flowController: flowController, frameQueue: newFrameSorter(), readChan: make(chan struct{}, 1), + readOnce: make(chan struct{}, 1), finalOffset: protocol.MaxByteCount, version: version, } @@ -81,6 +83,12 @@ func (s *receiveStream) StreamID() protocol.StreamID { // Read implements io.Reader. It is not thread safe! func (s *receiveStream) Read(p []byte) (int, error) { + // Concurrent use of Read is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.readOnce <- struct{}{} + defer func() { <-s.readOnce }() + s.mutex.Lock() completed, n, err := s.readImpl(p) s.mutex.Unlock() @@ -105,7 +113,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err return false, 0, s.closeForShutdownErr } - bytesRead := 0 + var bytesRead int var deadlineTimer *utils.Timer for bytesRead < len(p) { if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) { diff --git a/receive_stream_test.go b/receive_stream_test.go index 51a8414f230..20157bb408c 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -4,6 +4,8 @@ import ( "errors" "io" "runtime" + "sync" + "sync/atomic" "time" "github.com/golang/mock/gomock" @@ -403,6 +405,43 @@ var _ = Describe("Receive Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(io.EOF)) }) + + // Calling Read concurrently doesn't make any sense (and is forbidden), + // but we still want to make sure that we don't complete the stream more than once + // if the user misuses our API. + // This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"), + // which can be hard to debug. + // Note that even without the protection built into the receiveStream, this test + // is very timing-dependent, and would need to run a few hundred times to trigger the failure. + It("handles concurrent reads", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes() + var bytesRead protocol.ByteCount + mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes() + + var numCompleted int32 + mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { + atomic.AddInt32(&numCompleted, 1) + }).AnyTimes() + const num = 3 + var wg sync.WaitGroup + wg.Add(num) + for i := 0; i < num; i++ { + go func() { + defer wg.Done() + defer GinkgoRecover() + _, err := str.Read(make([]byte, 8)) + Expect(err).To(MatchError(io.EOF)) + }() + } + str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte("foobar"), + Fin: true, + }) + wg.Wait() + Expect(bytesRead).To(BeEquivalentTo(6)) + Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1)) + }) }) It("closes when CloseRemote is called", func() {