Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DTLS restart #1846

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 16 additions & 4 deletions datachannel.go
Expand Up @@ -70,7 +70,7 @@ func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelPara
return nil, err
}

err = d.open(transport)
err = d.open(transport, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -104,14 +104,14 @@ func (api *API) newDataChannel(params *DataChannelParameters, log logging.Levele
}

// open opens the datachannel over the sctp transport
func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
func (d *DataChannel) open(sctpTransport *SCTPTransport, restart bool) error {
association := sctpTransport.association()
if association == nil {
return errSCTPNotEstablished
}

d.mu.Lock()
if d.sctpTransport != nil { // already open
if d.sctpTransport != nil && !restart { // already open & not restarting
d.mu.Unlock()
return nil
}
Expand Down Expand Up @@ -170,6 +170,11 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
return err
}

// If restarting, the `Open` event should be triggered again, once.
if restart {
d.openHandlerOnce = sync.Once{}
}

// bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier
dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold)
dc.OnBufferedAmountLow(d.onBufferedAmountLow)
Expand Down Expand Up @@ -325,11 +330,18 @@ func (d *DataChannel) readLoop() {
n, isString, err := d.dataChannel.ReadDataChannel(buffer)
if err != nil {
rlBufPool.Put(buffer) // nolint:staticcheck

previousState := d.ReadyState()
d.setReadyState(DataChannelStateClosed)

if err != io.EOF {
d.onError(err)
}
d.onClose()

// https://www.w3.org/TR/webrtc/#announcing-a-data-channel-as-closed
if previousState != DataChannelStateClosed {
d.onClose()
}
return
}

Expand Down
25 changes: 25 additions & 0 deletions dtlstransport.go
Expand Up @@ -215,6 +215,31 @@ func (t *DTLSTransport) startSRTP() error {
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}

isAlreadyRunning := func() bool {
select {
case <-t.srtpReady:
return true
default:
return false
}
}()

if isAlreadyRunning {
if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

return nil
}

srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -15,7 +15,7 @@ require (
github.com/pion/rtp v1.7.4
github.com/pion/sctp v1.8.2
github.com/pion/sdp/v3 v3.0.4
github.com/pion/srtp/v2 v2.0.5
github.com/pion/srtp/v2 v2.0.6-0.20220304062923-d55e443f8e15
github.com/pion/transport v0.13.0
github.com/sclevine/agouti v3.0.0+incompatible
github.com/stretchr/testify v1.7.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Expand Up @@ -54,19 +54,17 @@ github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.6/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0=
github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U=
github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo=
github.com/pion/rtp v1.7.0/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko=
github.com/pion/rtp v1.7.4 h1:4dMbjb1SuynU5OpA3kz1zHK+u+eOCQjW3MAeVHf1ODA=
github.com/pion/rtp v1.7.4/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko=
github.com/pion/sctp v1.8.0/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s=
github.com/pion/sctp v1.8.2 h1:yBBCIrUMJ4yFICL3RIvR4eh/H2BTTvlligmSTy+3kiA=
github.com/pion/sctp v1.8.2/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s=
github.com/pion/sdp/v3 v3.0.4 h1:2Kf+dgrzJflNCSw3TV5v2VLeI0s/qkzy2r5jlR0wzf8=
github.com/pion/sdp/v3 v3.0.4/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk=
github.com/pion/srtp/v2 v2.0.5 h1:ks3wcTvIUE/GHndO3FAvROQ9opy0uLELpwHJaQ1yqhQ=
github.com/pion/srtp/v2 v2.0.5/go.mod h1:8k6AJlal740mrZ6WYxc4Dg6qDqqhxoRG2GSjlUhDF0A=
github.com/pion/srtp/v2 v2.0.6-0.20220304062923-d55e443f8e15 h1:qFdF9b185eGmlBr1OyizpP4a8RnS8tCT8MztCKofuQ8=
github.com/pion/srtp/v2 v2.0.6-0.20220304062923-d55e443f8e15/go.mod h1:Kp632EOcOX2wtB6njSY+oRamReUfEYINuaGmKIMHVlA=
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
Expand Down
54 changes: 53 additions & 1 deletion peerconnection.go
Expand Up @@ -1144,7 +1144,59 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
pc.ops.Enqueue(func() {
pc.startRTP(true, &desc, currentTransceivers)
})
} else if pc.dtlsTransport.State() != DTLSTransportStateNew {
fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed)
if fErr != nil {
return fErr
}

fingerPrintDidChange := true

for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints {
if fingerprint == fp.Value && fingerprintHash == fp.Algorithm {
fingerPrintDidChange = false
break
}
}

if fingerPrintDidChange {
pc.ops.Enqueue(func() {
// SCTP uses DTLS, so prevent any use, by locking, while
// DTLS is restarting.
pc.sctpTransport.lock.Lock()
defer pc.sctpTransport.lock.Unlock()

if dErr := pc.dtlsTransport.Stop(); dErr != nil {
pc.log.Warnf("Failed to stop DTLS: %s", dErr)
}

// libwebrtc switches the connection back to `new`.
pc.dtlsTransport.lock.Lock()
pc.dtlsTransport.onStateChange(DTLSTransportStateNew)
pc.dtlsTransport.lock.Unlock()

// Restart the dtls transport with updated fingerprints
err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRoleFromRemoteSDP(desc.parsed),
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
})
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
if err != nil {
pc.log.Warnf("Failed to restart DTLS: %s", err)
return
}

// If SCTP was enabled, restart it with the new DTLS transport.
if pc.sctpTransport.isStarted {
if dErr := pc.sctpTransport.restart(pc.dtlsTransport.conn); dErr != nil {
pc.log.Warnf("Failed to restart SCTP: %s", dErr)
return
}
}
})
}
}

return nil
}

Expand Down Expand Up @@ -1904,7 +1956,7 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn

// If SCTP already connected open all the channels
if pc.sctpTransport.State() == SCTPTransportStateConnected {
if err = d.open(pc.sctpTransport); err != nil {
if err = d.open(pc.sctpTransport, false); err != nil {
return nil, err
}
}
Expand Down