Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protect against concurrent use of Stream.Read #3380

Merged
merged 1 commit into from Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 9 additions & 1 deletion receive_stream.go
Expand Up @@ -47,6 +47,7 @@ type receiveStream struct {
resetRemotely bool // set when HandleResetStreamFrame() is called

readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
deadline time.Time

flowController flowcontrol.StreamFlowController
Expand All @@ -70,6 +71,7 @@ func newReceiveStream(
flowController: flowController,
frameQueue: newFrameSorter(),
readChan: make(chan struct{}, 1),
readOnce: make(chan struct{}, 1),
finalOffset: protocol.MaxByteCount,
version: version,
}
Expand All @@ -81,6 +83,12 @@ func (s *receiveStream) StreamID() protocol.StreamID {

// Read implements io.Reader. It is not thread safe!
func (s *receiveStream) Read(p []byte) (int, error) {
// Concurrent use of Read is not permitted (and doesn't make any sense),
// but sometimes people do it anyway.
// Make sure that we only execute one call at any given time to avoid hard to debug failures.
s.readOnce <- struct{}{}
defer func() { <-s.readOnce }()

s.mutex.Lock()
completed, n, err := s.readImpl(p)
s.mutex.Unlock()
Expand All @@ -105,7 +113,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
return false, 0, s.closeForShutdownErr
}

bytesRead := 0
var bytesRead int
var deadlineTimer *utils.Timer
for bytesRead < len(p) {
if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
Expand Down
39 changes: 39 additions & 0 deletions receive_stream_test.go
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"io"
"runtime"
"sync"
"sync/atomic"
"time"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -403,6 +405,43 @@ var _ = Describe("Receive Stream", func() {
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})

// Calling Read concurrently doesn't make any sense (and is forbidden),
// but we still want to make sure that we don't complete the stream more than once
// if the user misuses our API.
// This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"),
// which can be hard to debug.
// Note that even without the protection built into the receiveStream, this test
// is very timing-dependent, and would need to run a few hundred times to trigger the failure.
It("handles concurrent reads", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes()
var bytesRead protocol.ByteCount
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes()

var numCompleted int32
mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) {
atomic.AddInt32(&numCompleted, 1)
}).AnyTimes()
const num = 3
var wg sync.WaitGroup
wg.Add(num)
for i := 0; i < num; i++ {
go func() {
defer wg.Done()
defer GinkgoRecover()
_, err := str.Read(make([]byte, 8))
Expect(err).To(MatchError(io.EOF))
}()
}
str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte("foobar"),
Fin: true,
})
wg.Wait()
Expect(bytesRead).To(BeEquivalentTo(6))
Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1))
})
})

It("closes when CloseRemote is called", func() {
Expand Down