diff --git a/quinn/src/broadcast.rs b/quinn/src/broadcast.rs deleted file mode 100644 index 51463fed8..000000000 --- a/quinn/src/broadcast.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::task::{Context, Waker}; - -/// Helper for waking unpredictable numbers of tasks simultaneously -/// -/// # Rationale -/// -/// Sometimes we want to let an arbitrary number of tasks wait for the same transient condition. If -/// a task is polled and finds that the condition of interest is not in effect, it must register a -/// `Waker` to arrange to be polled when that may have changed. The number of such tasks is -/// indefinite, so we collect multiple `Waker`s in a `Vec` to be triggered en masse when the -/// condition arises. -/// -/// Complication arises from the spurious polling expected by futures. If each interested task -/// blindly registered a new `Waker` on finding the condition not in effect, the `Vec` would grow -/// with proportion to the (unbounded) number of spurious wakeups that interested tasks undergo. To -/// resolve this, we increment a generation counter every time we drain the `Vec`, and associate -/// with each interested task the generation at which it last registered. If a spurious wakeup -/// occurs, the task's generation is current, and we can avoid growing the `Vec`. If, however, the -/// wakeup is genuine but the condition of interest has already passed, then the task's generation -/// no longer matches the counter, and we infer that the task's `Waker` is no longer stored and a -/// new one must be recorded. -#[derive(Debug)] -pub struct Broadcast { - wakers: Vec, - generation: u64, -} - -impl Broadcast { - pub fn new() -> Self { - Self { - wakers: Vec::new(), - generation: 0, - } - } - - /// Ensure the next `wake` call will wake the calling task - /// - /// Checks the task-associated generation counter stored in `state`. If it's present and - /// current, we already have this task's `Waker` and no action is necessary. Otherwise, record a - /// `Waker` and store the current generation in `state`. - pub fn register(&mut self, cx: &mut Context, state: &mut State) { - if state.0 == Some(self.generation) { - return; - } - state.0 = Some(self.generation); - self.wakers.push(cx.waker().clone()); - } - - /// Wake all known `Waker`s - pub fn wake(&mut self) { - self.generation = self.generation.wrapping_add(1); - for waker in self.wakers.drain(..) { - waker.wake(); - } - } -} - -/// State maintained by each interested task -/// -/// Stores the generation at which the task previously registered a `Waker`, if any. -#[derive(Default)] -pub struct State(Option); diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 35ffe6cde..b777f87e4 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -16,12 +16,12 @@ use futures_core::Stream; use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId}; use rustc_hash::FxHashMap; use thiserror::Error; +use tokio::sync::Notify; use tokio::time::{sleep_until, Instant as TokioInstant, Sleep}; use tracing::info_span; use udp::UdpState; use crate::{ - broadcast::{self, Broadcast}, mutex::Mutex, poll_fn, recv_stream::RecvStream, @@ -324,11 +324,9 @@ impl Connection { /// Streams are cheap and instantaneous to open unless blocked by flow control. As a /// consequence, the peer won't be notified that a stream has been opened until the stream is /// actually used. - pub fn open_uni(&self) -> OpenUni { - OpenUni { - conn: self.0.clone(), - state: broadcast::State::default(), - } + pub async fn open_uni(&self) -> Result { + let (id, is_0rtt) = self.open(Dir::Uni).await?; + Ok(SendStream::new(self.0.clone(), id, is_0rtt)) } /// Initiate a new outgoing bidirectional stream. @@ -336,10 +334,35 @@ impl Connection { /// Streams are cheap and instantaneous to open unless blocked by flow control. As a /// consequence, the peer won't be notified that a stream has been opened until the stream is /// actually used. - pub fn open_bi(&self) -> OpenBi { - OpenBi { - conn: self.0.clone(), - state: broadcast::State::default(), + pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { + let (id, is_0rtt) = self.open(Dir::Bi).await?; + Ok(( + SendStream::new(self.0.clone(), id, is_0rtt), + RecvStream::new(self.0.clone(), id, is_0rtt), + )) + } + + async fn open(&self, dir: Dir) -> Result<(StreamId, bool), ConnectionError> { + loop { + let opening; + { + let mut conn = self.0.lock("open"); + if let Some(ref e) = conn.error { + return Err(e.clone()); + } + if let Some(id) = conn.inner.streams().open(dir) { + let is_0rtt = conn.inner.side().is_client() && conn.inner.is_handshaking(); + return Ok((id, is_0rtt)); + } + // Clone the `Arc` so we can wait on the underlying `Notify` without holding + // the lock. Store it in the outer scope to ensure it outlives the lock guard. + opening = conn.stream_opening[dir as usize].clone(); + // Construct the future while the lock is held to ensure we can't miss a wakeup if + // the `Notify` is signaled immediately after we release the lock. `await` it after + // the lock guard is out of scope. + opening.notified() + } + .await } } @@ -620,61 +643,6 @@ impl Stream for Datagrams { } } -/// A future that will resolve into an opened outgoing unidirectional stream -#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] -pub struct OpenUni { - conn: ConnectionRef, - state: broadcast::State, -} - -impl Future for OpenUni { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = self.get_mut(); - let mut conn = this.conn.lock("OpenUni::next"); - if let Some(ref e) = conn.error { - return Poll::Ready(Err(e.clone())); - } - if let Some(id) = conn.inner.streams().open(Dir::Uni) { - let is_0rtt = conn.inner.side().is_client() && conn.inner.is_handshaking(); - drop(conn); // Release lock for clone - return Poll::Ready(Ok(SendStream::new(this.conn.clone(), id, is_0rtt))); - } - conn.uni_opening.register(cx, &mut this.state); - Poll::Pending - } -} - -/// A future that will resolve into an opened outgoing bidirectional stream -#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] -pub struct OpenBi { - conn: ConnectionRef, - state: broadcast::State, -} - -impl Future for OpenBi { - type Output = Result<(SendStream, RecvStream), ConnectionError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let this = self.get_mut(); - let mut conn = this.conn.lock("OpenBi::next"); - if let Some(ref e) = conn.error { - return Poll::Ready(Err(e.clone())); - } - if let Some(id) = conn.inner.streams().open(Dir::Bi) { - let is_0rtt = conn.inner.side().is_client() && conn.inner.is_handshaking(); - drop(conn); // Release lock for clone - return Poll::Ready(Ok(( - SendStream::new(this.conn.clone(), id, is_0rtt), - RecvStream::new(this.conn.clone(), id, is_0rtt), - ))); - } - conn.bi_opening.register(cx, &mut this.state); - Poll::Pending - } -} - #[derive(Debug)] pub struct ConnectionRef(Arc>); @@ -701,8 +669,7 @@ impl ConnectionRef { endpoint_events, blocked_writers: FxHashMap::default(), blocked_readers: FxHashMap::default(), - uni_opening: Broadcast::new(), - bi_opening: Broadcast::new(), + stream_opening: [Arc::new(Notify::new()), Arc::new(Notify::new())], incoming_uni_streams_reader: None, incoming_bi_streams_reader: None, datagram_reader: None, @@ -762,8 +729,7 @@ pub struct ConnectionInner { endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, - uni_opening: Broadcast, - bi_opening: Broadcast, + stream_opening: [Arc; 2], incoming_uni_streams_reader: Option, incoming_bi_streams_reader: Option, datagram_reader: Option, @@ -886,11 +852,7 @@ impl ConnectionInner { } } Stream(StreamEvent::Available { dir }) => { - let tasks = match dir { - Dir::Uni => &mut self.uni_opening, - Dir::Bi => &mut self.bi_opening, - }; - tasks.wake(); + self.stream_opening[dir as usize].notify_one(); } Stream(StreamEvent::Finished { id }) => { if let Some(finishing) = self.finishing.remove(&id) { @@ -982,8 +944,8 @@ impl ConnectionInner { for (_, reader) in self.blocked_readers.drain() { reader.wake() } - self.uni_opening.wake(); - self.bi_opening.wake(); + self.stream_opening[Dir::Uni as usize].notify_waiters(); + self.stream_opening[Dir::Bi as usize].notify_waiters(); if let Some(x) = self.incoming_uni_streams_reader.take() { x.wake(); } diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 20c1b680b..03bb34f9c 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -19,15 +19,12 @@ use proto::{ self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig, }; use rustc_hash::FxHashMap; +use tokio::sync::Notify; use udp::{RecvMeta, UdpSocket, UdpState, BATCH_SIZE}; use crate::{ - broadcast::{self, Broadcast}, - connection::Connecting, - poll_fn, - work_limiter::WorkLimiter, - ConnectionEvent, EndpointConfig, EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, - SEND_TIME_BOUND, + connection::Connecting, poll_fn, work_limiter::WorkLimiter, ConnectionEvent, EndpointConfig, + EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND, }; /// A QUIC endpoint. @@ -226,16 +223,23 @@ impl Endpoint { /// [`close()`]: Endpoint::close /// [`Incoming`]: crate::Incoming pub async fn wait_idle(&self) { - let mut state = broadcast::State::default(); - poll_fn(move |cx| { - let endpoint = &mut *self.inner.lock().unwrap(); - if endpoint.connections.is_empty() { - return Poll::Ready(()); + loop { + let idle; + { + let endpoint = &mut *self.inner.lock().unwrap(); + if endpoint.connections.is_empty() { + break; + } + // Clone the `Arc` so we can wait on the underlying `Notify` without holding + // the lock. Store it in the outer scope to ensure it outlives the lock guard. + idle = endpoint.idle.clone(); + // Construct the future while the lock is held to ensure we can't miss a wakeup if + // the `Notify` is signaled immediately after we release the lock. `await` it after + // the lock guard is out of scope. + idle.notified() } - endpoint.idle.register(cx, &mut state); - Poll::Pending - }) - .await; + .await; + } } } @@ -321,7 +325,7 @@ pub(crate) struct EndpointInner { recv_limiter: WorkLimiter, recv_buf: Box<[u8]>, send_limiter: WorkLimiter, - idle: Broadcast, + idle: Arc, } impl EndpointInner { @@ -442,7 +446,7 @@ impl EndpointInner { if e.is_drained() { self.connections.senders.remove(&ch); if self.connections.is_empty() { - self.idle.wake(); + self.idle.notify_waiters(); } } if let Some(event) = self.inner.handle_event(ch, e) { @@ -581,7 +585,7 @@ impl EndpointRef { recv_buf: recv_buf.into(), recv_limiter: WorkLimiter::new(RECV_TIME_BOUND), send_limiter: WorkLimiter::new(SEND_TIME_BOUND), - idle: Broadcast::new(), + idle: Arc::new(Notify::new()), }))) } } diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index c2185bdfa..4a3818994 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -55,7 +55,6 @@ macro_rules! ready { }; } -mod broadcast; mod connection; mod endpoint; mod mutex; @@ -71,7 +70,7 @@ pub use proto::{ pub use crate::connection::{ Connecting, Connection, Datagrams, IncomingBiStreams, IncomingUniStreams, NewConnection, - OpenBi, OpenUni, SendDatagramError, UnknownStream, ZeroRttAccepted, + SendDatagramError, UnknownStream, ZeroRttAccepted, }; pub use crate::endpoint::{Endpoint, Incoming}; pub use crate::recv_stream::{ReadError, ReadExactError, ReadToEndError, RecvStream};