From 088ce44df9ae835761766f7b57b4dd905dc950d2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 2 Aug 2022 23:05:56 +0200 Subject: [PATCH] use a generic streams map for outgoing streams --- streams_map.go | 10 +- ...outgoing_uni.go => streams_map_outgoing.go | 64 ++--- streams_map_outgoing_bidi.go | 226 ------------------ streams_map_outgoing_generic.go | 224 ----------------- ...ic_test.go => streams_map_outgoing_test.go | 56 +++-- 5 files changed, 70 insertions(+), 510 deletions(-) rename streams_map_outgoing_uni.go => streams_map_outgoing.go (74%) delete mode 100644 streams_map_outgoing_bidi.go delete mode 100644 streams_map_outgoing_generic.go rename streams_map_outgoing_generic_test.go => streams_map_outgoing_test.go (89%) diff --git a/streams_map.go b/streams_map.go index 79c1ee91a86..b7fbeaa1799 100644 --- a/streams_map.go +++ b/streams_map.go @@ -55,8 +55,8 @@ type streamsMap struct { newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController mutex sync.Mutex - outgoingBidiStreams *outgoingBidiStreamsMap - outgoingUniStreams *outgoingUniStreamsMap + outgoingBidiStreams *outgoingStreamsMap[streamI] + outgoingUniStreams *outgoingStreamsMap[sendStreamI] incomingBidiStreams *incomingBidiStreamsMap incomingUniStreams *incomingUniStreamsMap reset bool @@ -85,7 +85,8 @@ func newStreamsMap( } func (m *streamsMap) initMaps() { - m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + m.outgoingBidiStreams = newOutgoingStreamsMap( + protocol.StreamTypeBidi, func(num protocol.StreamNum) streamI { id := num.StreamID(protocol.StreamTypeBidi, m.perspective) return newStream(id, m.sender, m.newFlowController(id), m.version) @@ -100,7 +101,8 @@ func (m *streamsMap) initMaps() { m.maxIncomingBidiStreams, m.sender.queueControlFrame, ) - m.outgoingUniStreams = newOutgoingUniStreamsMap( + m.outgoingUniStreams = newOutgoingStreamsMap( + protocol.StreamTypeUni, func(num protocol.StreamNum) sendStreamI { id := num.StreamID(protocol.StreamTypeUni, m.perspective) return newSendStream(id, m.sender, m.newFlowController(id), m.version) diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing.go similarity index 74% rename from streams_map_outgoing_uni.go rename to streams_map_outgoing.go index 8782364a54a..d4f249f023a 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing.go @@ -1,7 +1,3 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - package quic import ( @@ -12,10 +8,16 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type outgoingUniStreamsMap struct { +type outgoingStream interface { + updateSendWindow(protocol.ByteCount) + closeForShutdown(error) +} + +type outgoingStreamsMap[T outgoingStream] struct { mutex sync.RWMutex - streams map[protocol.StreamNum]sendStreamI + streamType protocol.StreamType + streams map[protocol.StreamNum]T openQueue map[uint64]chan struct{} lowestInQueue uint64 @@ -25,18 +27,20 @@ type outgoingUniStreamsMap struct { maxStream protocol.StreamNum // the maximum stream ID we're allowed to open blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - newStream func(protocol.StreamNum) sendStreamI + newStream func(protocol.StreamNum) T queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error } -func newOutgoingUniStreamsMap( - newStream func(protocol.StreamNum) sendStreamI, +func newOutgoingStreamsMap[T outgoingStream]( + streamType protocol.StreamType, + newStream func(protocol.StreamNum) T, queueControlFrame func(wire.Frame), -) *outgoingUniStreamsMap { - return &outgoingUniStreamsMap{ - streams: make(map[protocol.StreamNum]sendStreamI), +) *outgoingStreamsMap[T] { + return &outgoingStreamsMap[T]{ + streamType: streamType, + streams: make(map[protocol.StreamNum]T), openQueue: make(map[uint64]chan struct{}), maxStream: protocol.InvalidStreamNum, nextStream: 1, @@ -45,32 +49,32 @@ func newOutgoingUniStreamsMap( } } -func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { +func (m *outgoingStreamsMap[T]) OpenStream() (T, error) { m.mutex.Lock() defer m.mutex.Unlock() if m.closeErr != nil { - return nil, m.closeErr + return *new(T), m.closeErr } // if there are OpenStreamSync calls waiting, return an error here if len(m.openQueue) > 0 || m.nextStream > m.maxStream { m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} + return *new(T), streamOpenErr{errTooManyOpenStreams} } return m.openStream(), nil } -func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) { +func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { m.mutex.Lock() defer m.mutex.Unlock() if m.closeErr != nil { - return nil, m.closeErr + return *new(T), m.closeErr } if err := ctx.Err(); err != nil { - return nil, err + return *new(T), err } if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { @@ -92,13 +96,13 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI case <-ctx.Done(): m.mutex.Lock() delete(m.openQueue, queuePos) - return nil, ctx.Err() + return *new(T), ctx.Err() case <-waitChan: } m.mutex.Lock() if m.closeErr != nil { - return nil, m.closeErr + return *new(T), m.closeErr } if m.nextStream > m.maxStream { // no stream available. Continue waiting @@ -112,7 +116,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI } } -func (m *outgoingUniStreamsMap) openStream() sendStreamI { +func (m *outgoingStreamsMap[T]) openStream() T { s := m.newStream(m.nextStream) m.streams[m.nextStream] = s m.nextStream++ @@ -121,7 +125,7 @@ func (m *outgoingUniStreamsMap) openStream() sendStreamI { // maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, // if we haven't sent one for this offset yet -func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { +func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() { if m.blockedSent { return } @@ -131,17 +135,17 @@ func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { streamNum = m.maxStream } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: protocol.StreamTypeUni, + Type: m.streamType, StreamLimit: streamNum, }) m.blockedSent = true } -func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) { +func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) { m.mutex.RLock() if num >= m.nextStream { m.mutex.RUnlock() - return nil, streamError{ + return *new(T), streamError{ message: "peer attempted to open stream %d", nums: []protocol.StreamNum{num}, } @@ -151,7 +155,7 @@ func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, return s, nil } -func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { +func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -165,7 +169,7 @@ func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { return nil } -func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { +func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) { m.mutex.Lock() defer m.mutex.Unlock() @@ -183,7 +187,7 @@ func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { // UpdateSendWindow is called when the peer's transport parameters are received. // Only in the case of a 0-RTT handshake will we have open streams at this point. // We might need to update the send window, in case the server increased it. -func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { +func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) { m.mutex.Lock() for _, str := range m.streams { str.updateSendWindow(limit) @@ -192,7 +196,7 @@ func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { } // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingUniStreamsMap) unblockOpenSync() { +func (m *outgoingStreamsMap[T]) unblockOpenSync() { if len(m.openQueue) == 0 { return } @@ -211,7 +215,7 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() { } } -func (m *outgoingUniStreamsMap) CloseWithError(err error) { +func (m *outgoingStreamsMap[T]) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err for _, str := range m.streams { diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go deleted file mode 100644 index 3f7ec166ad1..00000000000 --- a/streams_map_outgoing_bidi.go +++ /dev/null @@ -1,226 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -type outgoingBidiStreamsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]streamI - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) streamI - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingBidiStreamsMap( - newStream func(protocol.StreamNum) streamI, - queueControlFrame func(wire.Frame), -) *outgoingBidiStreamsMap { - return &outgoingBidiStreamsMap{ - streams: make(map[protocol.StreamNum]streamI), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingBidiStreamsMap) openStream() streamI { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingBidiStreamsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingBidiStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingBidiStreamsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. - select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingBidiStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go deleted file mode 100644 index dde75043c2d..00000000000 --- a/streams_map_outgoing_generic.go +++ /dev/null @@ -1,224 +0,0 @@ -package quic - -import ( - "context" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" -//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" -type outgoingItemsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]item - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) item - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingItemsMap( - newStream func(protocol.StreamNum) item, - queueControlFrame func(wire.Frame), -) *outgoingItemsMap { - return &outgoingItemsMap{ - streams: make(map[protocol.StreamNum]item), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingItemsMap) OpenStream() (item, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingItemsMap) openStream() item { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingItemsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: streamTypeGeneric, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingItemsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. - select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingItemsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_test.go similarity index 89% rename from streams_map_outgoing_generic_test.go rename to streams_map_outgoing_test.go index 421fb4ae26c..ed2b821fb4d 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_test.go @@ -18,11 +18,13 @@ import ( var _ = Describe("Streams Map (outgoing)", func() { var ( - m *outgoingItemsMap - newItem func(num protocol.StreamNum) item + m *outgoingStreamsMap[*mockGenericStream] + newStr func(num protocol.StreamNum) *mockGenericStream mockSender *MockStreamSender ) + const streamType = 42 + // waitForEnqueued waits until there are n go routines waiting on OpenStreamSync() waitForEnqueued := func(n int) { Eventually(func() int { @@ -33,11 +35,11 @@ var _ = Describe("Streams Map (outgoing)", func() { } BeforeEach(func() { - newItem = func(num protocol.StreamNum) item { + newStr = func(num protocol.StreamNum) *mockGenericStream { return &mockGenericStream{num: num} } mockSender = NewMockStreamSender(mockCtrl) - m = newOutgoingItemsMap(newItem, mockSender.queueControlFrame) + m = newOutgoingStreamsMap[*mockGenericStream](streamType, newStr, mockSender.queueControlFrame) }) Context("no stream ID limit", func() { @@ -48,10 +50,10 @@ var _ = Describe("Streams Map (outgoing)", func() { It("opens streams", func() { str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) str, err = m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + Expect(str.num).To(Equal(protocol.StreamNum(2))) }) It("doesn't open streams after it has been closed", func() { @@ -66,7 +68,7 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) str, err := m.GetStream(1) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) }) It("errors when trying to get a stream that has not yet been opened", func() { @@ -107,10 +109,10 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) testErr := errors.New("test err") m.CloseWithError(testErr) - Expect(str1.(*mockGenericStream).closed).To(BeTrue()) - Expect(str1.(*mockGenericStream).closeErr).To(MatchError(testErr)) - Expect(str2.(*mockGenericStream).closed).To(BeTrue()) - Expect(str2.(*mockGenericStream).closeErr).To(MatchError(testErr)) + Expect(str1.closed).To(BeTrue()) + Expect(str1.closeErr).To(MatchError(testErr)) + Expect(str2.closed).To(BeTrue()) + Expect(str2.closeErr).To(MatchError(testErr)) }) It("updates the send window", func() { @@ -119,8 +121,8 @@ var _ = Describe("Streams Map (outgoing)", func() { str2, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) m.UpdateSendWindow(1337) - Expect(str1.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) - Expect(str2.(*mockGenericStream).sendWindow).To(BeEquivalentTo(1337)) + Expect(str1.sendWindow).To(BeEquivalentTo(1337)) + Expect(str2.sendWindow).To(BeEquivalentTo(1337)) }) }) @@ -145,7 +147,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) close(done) }() waitForEnqueued(1) @@ -173,7 +175,7 @@ var _ = Describe("Streams Map (outgoing)", func() { m.SetMaxStream(1000) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) }) It("opens streams in the right order", func() { @@ -183,7 +185,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) close(done1) }() waitForEnqueued(1) @@ -193,7 +195,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + Expect(str.num).To(Equal(protocol.StreamNum(2))) close(done2) }() waitForEnqueued(2) @@ -212,7 +214,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) close(done1) }() waitForEnqueued(1) @@ -232,7 +234,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + Expect(str.num).To(Equal(protocol.StreamNum(2))) close(done3) }() waitForEnqueued(3) @@ -284,7 +286,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer GinkgoRecover() str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(1))) + Expect(str.num).To(Equal(protocol.StreamNum(1))) close(openedSync) }() waitForEnqueued(1) @@ -297,7 +299,7 @@ var _ = Describe("Streams Map (outgoing)", func() { for { str, err := m.OpenStream() if err == nil { - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + Expect(str.num).To(Equal(protocol.StreamNum(2))) close(openend) return } @@ -340,7 +342,7 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(2))) + Expect(str.num).To(Equal(protocol.StreamNum(2))) }) It("queues a STREAMS_BLOCKED frame if no stream can be opened", func() { @@ -352,7 +354,9 @@ var _ = Describe("Streams Map (outgoing)", func() { } mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { - Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(6)) + bf := f.(*wire.StreamsBlockedFrame) + Expect(bf.Type).To(BeEquivalentTo(streamType)) + Expect(bf.StreamLimit).To(BeEquivalentTo(6)) }) _, err := m.OpenStream() Expect(err).To(HaveOccurred()) @@ -423,7 +427,7 @@ var _ = Describe("Streams Map (outgoing)", func() { defer close(doneChan) str, err := m.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(str.(*mockGenericStream).num).To(Equal(id)) + Expect(str.num).To(Equal(id)) }(c, protocol.StreamNum(i)) waitForEnqueued(i) } @@ -449,7 +453,7 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error())) } else { - Expect(str.(*mockGenericStream).num).To(Equal(protocol.StreamNum(n + 1))) + Expect(str.num).To(Equal(protocol.StreamNum(n + 1))) } } Expect(blockedAt).To(Equal(limits)) @@ -499,7 +503,7 @@ var _ = Describe("Streams Map (outgoing)", func() { } Expect(err).ToNot(HaveOccurred()) mutex.Lock() - streamIDs = append(streamIDs, int(str.(*mockGenericStream).num)) + streamIDs = append(streamIDs, int(str.num)) mutex.Unlock() }(c, protocol.StreamNum(i)) waitForEnqueued(i)