Skip to content

Commit

Permalink
Merge pull request #66 from libp2p/max-incoming-streams
Browse files Browse the repository at this point in the history
limit the number of concurrent incoming streams
  • Loading branch information
marten-seemann committed Nov 20, 2021
2 parents d6101de + f55df18 commit 0bd012d
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 5 deletions.
6 changes: 6 additions & 0 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ type Config struct {
// an expectation that things will move along quickly.
ConnectionWriteTimeout time.Duration

// MaxIncomingStreams is maximum number of concurrent incoming streams
// that we accept. If the peer tries to open more streams, those will be
// reset immediately.
MaxIncomingStreams uint32

// InitialStreamWindowSize is used to control the initial
// window size that we allow for a stream.
InitialStreamWindowSize uint32
Expand Down Expand Up @@ -65,6 +70,7 @@ func DefaultConfig() *Config {
EnableKeepAlive: true,
KeepAliveInterval: 30 * time.Second,
ConnectionWriteTimeout: 10 * time.Second,
MaxIncomingStreams: 1000,
InitialStreamWindowSize: initialStreamWindow,
MaxStreamWindowSize: maxStreamWindow,
LogOutput: os.Stderr,
Expand Down
23 changes: 18 additions & 5 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"sync/atomic"
"time"

"github.com/libp2p/go-buffer-pool"
pool "github.com/libp2p/go-buffer-pool"
)

// Session is used to wrap a reliable ordered connection and to
Expand Down Expand Up @@ -55,9 +55,10 @@ type Session struct {
// streams maps a stream id to a stream, and inflight has an entry
// for any outgoing stream that has not yet been established. Both are
// protected by streamLock.
streams map[uint32]*Stream
inflight map[uint32]struct{}
streamLock sync.Mutex
numIncomingStreams uint32
streams map[uint32]*Stream
inflight map[uint32]struct{}
streamLock sync.Mutex

// synCh acts like a semaphore. It is sized to the AcceptBacklog which
// is assumed to be symmetric between the client and server. This allows
Expand Down Expand Up @@ -735,6 +736,15 @@ func (s *Session) incomingStream(id uint32) error {
return ErrDuplicateStream
}

if s.numIncomingStreams >= s.config.MaxIncomingStreams {
// too many active streams at the same time
s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset")
delete(s.streams, id)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
}

s.numIncomingStreams++
// Register the stream
s.streams[id] = stream

Expand All @@ -744,7 +754,7 @@ func (s *Session) incomingStream(id uint32) error {
return nil
default:
// Backlog exceeded! RST the stream
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset")
delete(s.streams, id)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
Expand All @@ -764,6 +774,9 @@ func (s *Session) closeStream(id uint32) {
}
delete(s.inflight, id)
}
if s.client == (id%2 == 0) {
s.numIncomingStreams--
}
delete(s.streams, id)
s.streamLock.Unlock()
}
Expand Down
52 changes: 52 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1732,3 +1732,55 @@ func TestInitialStreamWindow(t *testing.T) {
}
}
}

func TestMaxIncomingStreams(t *testing.T) {
const maxIncomingStreams = 5
conn1, conn2 := testConn()
client, err := Client(conn1, DefaultConfig())
require.NoError(t, err)
defer client.Close()

conf := DefaultConfig()
conf.MaxIncomingStreams = maxIncomingStreams
server, err := Server(conn2, conf)
require.NoError(t, err)
defer server.Close()

strChan := make(chan *Stream, maxIncomingStreams)
go func() {
defer close(strChan)
for {
str, err := server.AcceptStream()
if err != nil {
return
}
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
strChan <- str
}
}()

for i := 0; i < maxIncomingStreams; i++ {
str, err := client.OpenStream(context.Background())
require.NoError(t, err)
_, err = str.Read(make([]byte, 6))
require.NoError(t, err)
require.NoError(t, str.CloseWrite())
}
// The server now has maxIncomingStreams incoming streams.
// It will now reset the next stream that is opened.
str, err := client.OpenStream(context.Background())
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.EqualError(t, err, "stream reset")

// Now close one of the streams.
// This should then allow the client to open a new stream.
require.NoError(t, (<-strChan).Close())
str, err = client.OpenStream(context.Background())
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.NoError(t, err)
}

0 comments on commit 0bd012d

Please sign in to comment.