Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a generic streams map for outgoing streams #3488

Merged
merged 1 commit into from Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions streams_map.go
Expand Up @@ -55,8 +55,8 @@ type streamsMap struct {
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController

mutex sync.Mutex
outgoingBidiStreams *outgoingBidiStreamsMap
outgoingUniStreams *outgoingUniStreamsMap
outgoingBidiStreams *outgoingStreamsMap[streamI]
outgoingUniStreams *outgoingStreamsMap[sendStreamI]
incomingBidiStreams *incomingBidiStreamsMap
incomingUniStreams *incomingUniStreamsMap
reset bool
Expand Down Expand Up @@ -85,7 +85,8 @@ func newStreamsMap(
}

func (m *streamsMap) initMaps() {
m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
m.outgoingBidiStreams = newOutgoingStreamsMap(
protocol.StreamTypeBidi,
func(num protocol.StreamNum) streamI {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
return newStream(id, m.sender, m.newFlowController(id), m.version)
Expand All @@ -100,7 +101,8 @@ func (m *streamsMap) initMaps() {
m.maxIncomingBidiStreams,
m.sender.queueControlFrame,
)
m.outgoingUniStreams = newOutgoingUniStreamsMap(
m.outgoingUniStreams = newOutgoingStreamsMap(
protocol.StreamTypeUni,
func(num protocol.StreamNum) sendStreamI {
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
return newSendStream(id, m.sender, m.newFlowController(id), m.version)
Expand Down
64 changes: 34 additions & 30 deletions streams_map_outgoing_uni.go → streams_map_outgoing.go
@@ -1,7 +1,3 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny

package quic

import (
Expand All @@ -12,10 +8,16 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)

type outgoingUniStreamsMap struct {
type outgoingStream interface {
updateSendWindow(protocol.ByteCount)
closeForShutdown(error)
}

type outgoingStreamsMap[T outgoingStream] struct {
mutex sync.RWMutex

streams map[protocol.StreamNum]sendStreamI
streamType protocol.StreamType
streams map[protocol.StreamNum]T

openQueue map[uint64]chan struct{}
lowestInQueue uint64
Expand All @@ -25,18 +27,20 @@ type outgoingUniStreamsMap struct {
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream

newStream func(protocol.StreamNum) sendStreamI
newStream func(protocol.StreamNum) T
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)

closeErr error
}

func newOutgoingUniStreamsMap(
newStream func(protocol.StreamNum) sendStreamI,
func newOutgoingStreamsMap[T outgoingStream](
streamType protocol.StreamType,
newStream func(protocol.StreamNum) T,
queueControlFrame func(wire.Frame),
) *outgoingUniStreamsMap {
return &outgoingUniStreamsMap{
streams: make(map[protocol.StreamNum]sendStreamI),
) *outgoingStreamsMap[T] {
return &outgoingStreamsMap[T]{
streamType: streamType,
streams: make(map[protocol.StreamNum]T),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
Expand All @@ -45,32 +49,32 @@ func newOutgoingUniStreamsMap(
}
}

func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) {
func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

if m.closeErr != nil {
return nil, m.closeErr
return *new(T), m.closeErr
}

// if there are OpenStreamSync calls waiting, return an error here
if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
m.maybeSendBlockedFrame()
return nil, streamOpenErr{errTooManyOpenStreams}
return *new(T), streamOpenErr{errTooManyOpenStreams}
}
return m.openStream(), nil
}

func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) {
func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

if m.closeErr != nil {
return nil, m.closeErr
return *new(T), m.closeErr
}

if err := ctx.Err(); err != nil {
return nil, err
return *new(T), err
}

if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
Expand All @@ -92,13 +96,13 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
return nil, ctx.Err()
return *new(T), ctx.Err()
case <-waitChan:
}
m.mutex.Lock()

if m.closeErr != nil {
return nil, m.closeErr
return *new(T), m.closeErr
}
if m.nextStream > m.maxStream {
// no stream available. Continue waiting
Expand All @@ -112,7 +116,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI
}
}

func (m *outgoingUniStreamsMap) openStream() sendStreamI {
func (m *outgoingStreamsMap[T]) openStream() T {
s := m.newStream(m.nextStream)
m.streams[m.nextStream] = s
m.nextStream++
Expand All @@ -121,7 +125,7 @@ func (m *outgoingUniStreamsMap) openStream() sendStreamI {

// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset,
// if we haven't sent one for this offset yet
func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() {
func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() {
if m.blockedSent {
return
}
Expand All @@ -131,17 +135,17 @@ func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() {
streamNum = m.maxStream
}
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: protocol.StreamTypeUni,
Type: m.streamType,
StreamLimit: streamNum,
})
m.blockedSent = true
}

func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) {
func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) {
m.mutex.RLock()
if num >= m.nextStream {
m.mutex.RUnlock()
return nil, streamError{
return *new(T), streamError{
message: "peer attempted to open stream %d",
nums: []protocol.StreamNum{num},
}
Expand All @@ -151,7 +155,7 @@ func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI,
return s, nil
}

func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error {
func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -165,7 +169,7 @@ func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error {
return nil
}

func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) {
func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -183,7 +187,7 @@ func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) {
// UpdateSendWindow is called when the peer's transport parameters are received.
// Only in the case of a 0-RTT handshake will we have open streams at this point.
// We might need to update the send window, in case the server increased it.
func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) {
func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
m.mutex.Lock()
for _, str := range m.streams {
str.updateSendWindow(limit)
Expand All @@ -192,7 +196,7 @@ func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) {
}

// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
func (m *outgoingUniStreamsMap) unblockOpenSync() {
func (m *outgoingStreamsMap[T]) unblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
Expand All @@ -211,7 +215,7 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() {
}
}

func (m *outgoingUniStreamsMap) CloseWithError(err error) {
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
m.mutex.Lock()
m.closeErr = err
for _, str := range m.streams {
Expand Down