webrtc: close data channels cleanly #2724

Closed
wants to merge 2 commits into from
27 changes: 21 additions & 6 deletions p2p/transport/webrtc/connection.go
@@ -21,7 +21,10 @@ import (

var _ tpt.CapableConn = &connection{}

const maxAcceptQueueLen = 256
const (
maxAcceptQueueLen = 256
maxInvalidDataChannelClosures = 10
)

type errConnectionTimeout struct{}

@@ -51,9 +54,10 @@ type connection struct {
remoteKey ic.PubKey
remoteMultiaddr ma.Multiaddr

m sync.Mutex
streams map[uint16]*stream
nextStreamID atomic.Int32
m sync.Mutex
streams map[uint16]*stream
nextStreamID atomic.Int32
invalidDataChannelClosures atomic.Int32

acceptQueue chan dataChannel

@@ -158,7 +162,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
dc.Close()
return nil, fmt.Errorf("detach channel failed for stream(%d): %w", streamID, err)
}
str := newStream(dc, rwc, func() { c.removeStream(streamID) })
str := newStream(dc, rwc, maxRTT, func() { c.removeStream(streamID) }, c.onDataChannelClose)
if err := c.addStream(str); err != nil {
str.Reset()
return nil, fmt.Errorf("failed to add stream(%d) to connection: %w", streamID, err)
@@ -171,7 +175,7 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) {
case <-c.ctx.Done():
return nil, c.closeErr
case dc := <-c.acceptQueue:
str := newStream(dc.channel, dc.stream, func() { c.removeStream(*dc.channel.ID()) })
str := newStream(dc.channel, dc.stream, maxRTT, func() { c.removeStream(*dc.channel.ID()) }, c.onDataChannelClose)
if err := c.addStream(str); err != nil {
str.Reset()
return nil, err
Expand Down Expand Up @@ -207,6 +211,17 @@ func (c *connection) removeStream(id uint16) {
delete(c.streams, id)
}

func (c *connection) onDataChannelClose(remoteClosed bool) {
if !remoteClosed {
if c.invalidDataChannelClosures.Add(1) > maxInvalidDataChannelClosures {
c.closeOnce.Do(func() {
log.Error("closing connection as peer is not closing datachannels: ", c.RemotePeer(), c.RemoteMultiaddr())
c.closeWithError(errors.New("peer is not closing datachannels"))
})
}
}
}

func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed {
c.closeOnce.Do(func() {
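
A note on the mechanism above: the connection counts data channels that close without the remote's participation and tears itself down once the count exceeds maxInvalidDataChannelClosures. A minimal self-contained sketch of that counter-plus-once pattern, using a hypothetical conn type in place of the real connection:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

const maxInvalidClosures = 10 // stands in for maxInvalidDataChannelClosures

// conn is a hypothetical stand-in: an atomic counter trips a once-guarded
// close after too many closures the remote did not participate in.
type conn struct {
	invalidClosures atomic.Int32
	closeOnce       sync.Once
}

func (c *conn) onDataChannelClose(remoteClosed bool) {
	if remoteClosed {
		return // the peer closed its half too; nothing to count
	}
	if c.invalidClosures.Add(1) > maxInvalidClosures {
		c.closeOnce.Do(func() {
			fmt.Println("closing connection: peer is not closing datachannels")
		})
	}
}

func main() {
	c := &conn{}
	for i := 0; i < 12; i++ {
		c.onDataChannelClose(false) // every closure times out without the remote
	}
	// prints the close message exactly once, despite two counts past the limit
}
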
2 changes: 1 addition & 1 deletion p2p/transport/webrtc/listener.go
@@ -257,7 +257,7 @@ func (l *listener) setupConnection(
if err != nil {
return nil, err
}
handshakeChannel := newStream(w.HandshakeDataChannel, rwc, func() {})
handshakeChannel := newStream(w.HandshakeDataChannel, rwc, maxRTT, nil, nil)
// we do not yet know A's peer ID so accept any inbound
remotePubKey, err := l.transport.noiseHandshake(ctx, w.PeerConnection, handshakeChannel, "", crypto.SHA256, true)
if err != nil {
170 changes: 94 additions & 76 deletions p2p/transport/webrtc/stream.go
@@ -2,6 +2,7 @@ package libp2pwebrtc

import (
"errors"
"io"
"os"
"sync"
"time"
@@ -35,7 +36,6 @@ const (
// add messages to the send buffer once there is space for 1 full
// sized message.
bufferedAmountLowThreshold = maxBufferedAmount / 2

// Proto overhead assumption is 5 bytes
protoOverhead = 5
// Varint overhead is assumed to be 2 bytes. This is safe since
@@ -45,9 +45,9 @@
// is less than or equal to 2 ^ 14, the varint will not be more than
// 2 bytes in length.
varintOverhead = 2
// maxFINACKWait is the maximum amount of time a stream will wait to read
// FIN_ACK before closing the data channel
maxFINACKWait = 10 * time.Second
// maxRTT is an estimate of the maximum RTT between peers.
// We use it to bound how long we wait for FIN_ACK and data channel close messages from the peer.
maxRTT = 10 * time.Second
)

type receiveState uint8
@@ -89,37 +89,36 @@ type stream struct {
writeDeadline time.Time

controlMessageReaderOnce sync.Once
// controlMessageReaderEndTime is the end time for reading FIN_ACK from the control
// message reader. We cannot rely on SetReadDeadline to do this since that is prone to
// a race condition where a previous deadline timer fires after the latest call to
// SetReadDeadline
// See: https://github.com/pion/sctp/pull/290
controlMessageReaderEndTime time.Time
controlMessageReaderDone sync.WaitGroup

onDone func()

onCloseOnce sync.Once
onClose func()
onDataChannelClose func(remoteClosed bool)
id uint16 // for logging purposes
dataChannel *datachannel.DataChannel
closeForShutdownErr error
isClosed bool
rtt time.Duration
}

var _ network.MuxedStream = &stream{}

func newStream(
channel *webrtc.DataChannel,
rwc datachannel.ReadWriteCloser,
onDone func(),
rtt time.Duration,
onClose func(),
onDataChannelClose func(remoteClosed bool),
) *stream {
s := &stream{
reader: pbio.NewDelimitedReader(rwc, maxMessageSize),
writer: pbio.NewDelimitedWriter(rwc),
writeStateChanged: make(chan struct{}, 1),
id: *channel.ID(),
dataChannel: rwc.(*datachannel.DataChannel),
onDone: onDone,
reader: pbio.NewDelimitedReader(rwc, maxMessageSize),
writer: pbio.NewDelimitedWriter(rwc),
writeStateChanged: make(chan struct{}, 1),
id: *channel.ID(),
dataChannel: rwc.(*datachannel.DataChannel),
onClose: onClose,
onDataChannelClose: onDataChannelClose,
rtt: rtt,
}
// released when the controlMessageReader goroutine exits
s.controlMessageReaderDone.Add(1)
s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
s.dataChannel.OnBufferedAmountLow(func() {
s.notifyWriteStateChanged()
@@ -129,55 +128,46 @@
}

func (s *stream) Close() error {
defer s.signalClose()
s.mx.Lock()
isClosed := s.closeForShutdownErr != nil
s.mx.Unlock()
if isClosed {
if s.closeForShutdownErr != nil || s.isClosed {
s.mx.Unlock()
return nil
}
s.isClosed = true
closeWriteErr := s.closeWriteUnlocked()
closeReadErr := s.closeReadUnlocked()
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
s.mx.Unlock()

closeWriteErr := s.CloseWrite()
closeReadErr := s.CloseRead()
if closeWriteErr != nil || closeReadErr != nil {
s.Reset()
return errors.Join(closeWriteErr, closeReadErr)
}

s.mx.Lock()
if s.controlMessageReaderEndTime.IsZero() {
s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait)
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
go func() {
s.controlMessageReaderDone.Wait()
s.cleanup()
}()
}
s.mx.Unlock()
return nil
}

func (s *stream) Reset() error {
defer s.signalClose()
s.mx.Lock()
isClosed := s.closeForShutdownErr != nil
s.mx.Unlock()
if isClosed {
defer s.mx.Unlock()
if s.closeForShutdownErr != nil {
return nil
}

defer s.cleanup()
cancelWriteErr := s.cancelWrite()
closeReadErr := s.CloseRead()
// reset even if it's closed already
s.isClosed = true
cancelWriteErr := s.cancelWriteUnlocked()
closeReadErr := s.closeReadUnlocked()
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
return errors.Join(closeReadErr, cancelWriteErr)
return errors.Join(cancelWriteErr, closeReadErr)
}

func (s *stream) closeForShutdown(closeErr error) {
defer s.cleanup()

defer s.signalClose()
s.mx.Lock()
defer s.mx.Unlock()

s.closeForShutdownErr = closeErr
s.isClosed = true
s.notifyWriteStateChanged()
}

@@ -223,47 +213,54 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
}

// spawnControlMessageReader is used for processing control messages after the reader is closed.
// It is also responsible for closing the datachannel once the stream is closed
func (s *stream) spawnControlMessageReader() {
s.controlMessageReaderOnce.Do(func() {
// Spawn a goroutine to ensure that we're not holding any locks
go func() {
defer s.controlMessageReaderDone.Done()
// cleanup the sctp deadline timer goroutine
defer s.setDataChannelReadDeadline(time.Time{})

setDeadline := func() bool {
if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) {
s.setDataChannelReadDeadline(s.controlMessageReaderEndTime)
return true
}
return false
}

// Unblock any Read call waiting on reader.ReadMsg
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))

s.readerMx.Lock()
// We have the lock; any readers blocked on reader.ReadMsg have exited.
// From this point onwards only this goroutine will do reader.ReadMsg.

//lint:ignore SA2001 we just want to ensure any existing readers have exited.
// Read calls from this point onwards will exit immediately on checking
// released after write half is closed
s.mx.Lock()

// Read calls after lock release will exit immediately on checking
// s.readState
s.readerMx.Unlock()

s.mx.Lock()
defer s.mx.Unlock()

if s.nextMessage != nil {
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
}
for s.closeForShutdownErr == nil &&
s.sendState != sendStateDataReceived && s.sendState != sendStateReset {
var msg pb.Message
if !setDeadline() {
return

var endTime time.Time
var msg pb.Message
for {
// connection closed
if s.closeForShutdownErr != nil {
break
}
// write half completed
if s.sendState == sendStateDataReceived || s.sendState == sendStateReset {
break
}
// deadline exceeded
if !endTime.IsZero() && time.Now().After(endTime) {
break
}

// The write half is closed. Wait up to 1 RTT for the FIN_ACK before erroring
if s.sendState == sendStateDataSent && endTime.IsZero() {
endTime = time.Now().Add(s.rtt)
}
s.setDataChannelReadDeadline(endTime)
s.mx.Unlock()
err := s.reader.ReadMsg(&msg)
s.mx.Lock()
@@ -274,21 +271,42 @@
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
return
break
}
s.processIncomingFlag(msg.Flag)
}

s.mx.Unlock()
remoteClosed := s.closeDataChannel()
if s.onDataChannelClose != nil {
s.onDataChannelClose(remoteClosed)
}
}()
})
}

func (s *stream) cleanup() {
// Even if we close the datachannel, pion keeps a reference to the datachannel around.
// Remove the onBufferedAmountLow callback to ensure that we at least garbage collect
// memory we allocated for this stream.
s.dataChannel.OnBufferedAmountLow(nil)
// closeDataChannel closes the datachannel and waits up to 1 RTT for the remote to close its end
func (s *stream) closeDataChannel() bool {
s.dataChannel.Close()
if s.onDone != nil {
s.onDone()
endTime := time.Now().Add(s.rtt)
var msg pb.Message
for {
if time.Now().After(endTime) {
return false
}
s.setDataChannelReadDeadline(endTime)
err := s.reader.ReadMsg(&msg)
if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
return err == io.EOF
}
}

func (s *stream) signalClose() {
s.onCloseOnce.Do(func() {
if s.onClose != nil {
s.onClose()
}
})
}
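
The new closeDataChannel above reports whether the remote also closed the channel: after closing the local half it keeps reading for up to one RTT, treating io.EOF as a clean remote close and a timeout as the remote failing to close. A standalone sketch of that drain loop, assuming only a reader with a settable deadline (the deadlineReader interface and waitForRemoteClose name are hypothetical):

package sketch

import (
	"errors"
	"io"
	"os"
	"time"
)

// deadlineReader is the minimal surface assumed from a detached datachannel.
type deadlineReader interface {
	io.Reader
	SetReadDeadline(t time.Time) error
}

// waitForRemoteClose drains reads for up to rtt after our side is closed.
// io.EOF means the remote closed its half in time; a timeout means it did not.
func waitForRemoteClose(r deadlineReader, rtt time.Duration) bool {
	endTime := time.Now().Add(rtt)
	buf := make([]byte, 1024)
	for {
		if time.Now().After(endTime) {
			return false // remote never closed; callers count this as invalid
		}
		r.SetReadDeadline(endTime)
		_, err := r.Read(buf)
		if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
			continue // drain leftover data, then fall out via the time check
		}
		return errors.Is(err, io.EOF) // clean remote close vs. transport error
	}
}
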
4 changes: 4 additions & 0 deletions p2p/transport/webrtc/stream_read.go
@@ -103,6 +103,10 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error {
func (s *stream) CloseRead() error {
s.mx.Lock()
defer s.mx.Unlock()
return s.closeReadUnlocked()
}

func (s *stream) closeReadUnlocked() error {
var err error
if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil {
err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})
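
The stream_read.go change splits CloseRead into an exported method that acquires the mutex and an unexported closeReadUnlocked helper, so Close and Reset, which already hold s.mx, can compose the read and write halves under a single critical section. A minimal sketch of that Go idiom with a hypothetical half type:

package main

import (
	"fmt"
	"sync"
)

// half demonstrates the exported-locks / unexported-assumes-locked split:
// public methods take the mutex, internal callers already holding it use
// the *Unlocked variant directly.
type half struct {
	mx       sync.Mutex
	readDone bool
	sendDone bool
}

func (h *half) CloseRead() error {
	h.mx.Lock()
	defer h.mx.Unlock()
	return h.closeReadUnlocked()
}

func (h *half) closeReadUnlocked() error {
	h.readDone = true // caller must hold h.mx
	return nil
}

// Close updates both halves in one critical section; calling the exported
// CloseRead here instead would deadlock on the non-reentrant mutex.
func (h *half) Close() error {
	h.mx.Lock()
	defer h.mx.Unlock()
	h.sendDone = true
	return h.closeReadUnlocked()
}

func main() {
	h := &half{}
	fmt.Println(h.Close(), h.readDone, h.sendDone) // <nil> true true
}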