From 9d9488db67136651f85d839efcb5f61aba8531c9 Mon Sep 17 00:00:00 2001 From: Hayden Stainsby Date: Wed, 10 Aug 2022 21:30:05 +0200 Subject: [PATCH] sync: remove broadcast channel slot level closed flag (#4867) The broadcast channel allows multiple senders to send messages to multiple receivers, where each receiver receives messages starting from when it subscribes. After all senders are dropped, the receivers will continue to receive all waiting messages in the buffer and then receive a `Closed` error. To mark that a channel has closed, it stores two closed flags, one on the channel level and another in the buffer slot *after* the last used slot (this may also be the earliest entry being kept for lagged receivers, see #2425). However, we don't need both closed flags, keeping the channel level closed flag is sufficient. Without the slot level closed flag, each receiver receives each message until it is up to date and for that receiver the channel is empty. Then, the actual return message is chosen depending on the channel level closed flag; if the channel is NOT closed, then `Empty` is returned, if the channel is closed then `Closed` is returned instead. With the modified logic, there is no longer a need to append a closed token to the internal buffer (by setting the slot level closed flag on the next slot). This fixes the off by one error described in #4814, which caused a receiver which was created after the channel was already closed to get `Empty` from `try_recv` (or hang forever when calling `recv`) instead of receiving `Closed`. As a bonus, we save a single `bool` on each buffer slot. Refs: #4814 --- tokio/src/sync/broadcast.rs | 123 +++++++++++++--------------------- tokio/tests/sync_broadcast.rs | 11 ++- 2 files changed, 58 insertions(+), 76 deletions(-) diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 69ede8f89ce..426011e79c3 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -336,9 +336,6 @@ struct Slot { /// Uniquely identifies the `send` stored in the slot. pos: u64, - /// True signals the channel is closed. - closed: bool, - /// The value being broadcast. /// /// The value is set by `send` when the write lock is held. When a reader @@ -452,7 +449,6 @@ pub fn channel(mut capacity: usize) -> (Sender, Receiver) { buffer.push(RwLock::new(Slot { rem: AtomicUsize::new(0), pos: (i as u64).wrapping_sub(capacity as u64), - closed: false, val: UnsafeCell::new(None), })); } @@ -537,8 +533,43 @@ impl Sender { /// } /// ``` pub fn send(&self, value: T) -> Result> { - self.send2(Some(value)) - .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap())) + let mut tail = self.shared.tail.lock(); + + if tail.rx_cnt == 0 { + return Err(SendError(value)); + } + + // Position to write into + let pos = tail.pos; + let rem = tail.rx_cnt; + let idx = (pos & self.shared.mask as u64) as usize; + + // Update the tail position + tail.pos = tail.pos.wrapping_add(1); + + // Get the slot + let mut slot = self.shared.buffer[idx].write().unwrap(); + + // Track the position + slot.pos = pos; + + // Set remaining receivers + slot.rem.with_mut(|v| *v = rem); + + // Write the value + slot.val = UnsafeCell::new(Some(value)); + + // Release the slot lock before notifying the receivers. + drop(slot); + + tail.notify_rx(); + + // Release the mutex. This must happen after the slot lock is released, + // otherwise the writer lock bit could be cleared while another thread + // is in the critical section. + drop(tail); + + Ok(rem) } /// Creates a new [`Receiver`] handle that will receive values sent **after** @@ -610,49 +641,11 @@ impl Sender { tail.rx_cnt } - fn send2(&self, value: Option) -> Result>> { + fn close_channel(&self) { let mut tail = self.shared.tail.lock(); - - if tail.rx_cnt == 0 { - return Err(SendError(value)); - } - - // Position to write into - let pos = tail.pos; - let rem = tail.rx_cnt; - let idx = (pos & self.shared.mask as u64) as usize; - - // Update the tail position - tail.pos = tail.pos.wrapping_add(1); - - // Get the slot - let mut slot = self.shared.buffer[idx].write().unwrap(); - - // Track the position - slot.pos = pos; - - // Set remaining receivers - slot.rem.with_mut(|v| *v = rem); - - // Set the closed bit if the value is `None`; otherwise write the value - if value.is_none() { - tail.closed = true; - slot.closed = true; - } else { - slot.val.with_mut(|ptr| unsafe { *ptr = value }); - } - - // Release the slot lock before notifying the receivers. - drop(slot); + tail.closed = true; tail.notify_rx(); - - // Release the mutex. This must happen after the slot lock is released, - // otherwise the writer lock bit could be cleared while another thread - // is in the critical section. - drop(tail); - - Ok(rem) } } @@ -700,7 +693,7 @@ impl Clone for Sender { impl Drop for Sender { fn drop(&mut self) { if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) { - let _ = self.send2(None); + self.close_channel(); } } } @@ -784,14 +777,6 @@ impl Receiver { let mut slot = self.shared.buffer[idx].read().unwrap(); if slot.pos != self.next { - let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); - - // The receiver has read all current values in the channel and there - // is no waiter to register - if waiter.is_none() && next_pos == self.next { - return Err(TryRecvError::Empty); - } - // Release the `slot` lock before attempting to acquire the `tail` // lock. This is required because `send2` acquires the tail lock // first followed by the slot lock. Acquiring the locks in reverse @@ -813,6 +798,13 @@ impl Receiver { let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); if next_pos == self.next { + // At this point the channel is empty for *this* receiver. If + // it's been closed, then that's what we return, otherwise we + // set a waker and return empty. + if tail.closed { + return Err(TryRecvError::Closed); + } + // Store the waker if let Some((waiter, waker)) = waiter { // Safety: called while locked. @@ -846,22 +838,7 @@ impl Receiver { // catch up by skipping dropped messages and setting the // internal cursor to the **oldest** message stored by the // channel. - // - // However, finding the oldest position is a bit more - // complicated than `tail-position - buffer-size`. When - // the channel is closed, the tail position is incremented to - // signal a new `None` message, but `None` is not stored in the - // channel itself (see issue #2425 for why). - // - // To account for this, if the channel is closed, the tail - // position is decremented by `buffer-size + 1`. - let mut adjust = 0; - if tail.closed { - adjust = 1 - } - let next = tail - .pos - .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); let missed = next.wrapping_sub(self.next); @@ -882,10 +859,6 @@ impl Receiver { self.next = self.next.wrapping_add(1); - if slot.closed { - return Err(TryRecvError::Closed); - } - Ok(RecvGuard { slot }) } } diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index cee888dd2c0..9aa34841e26 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -47,7 +47,7 @@ macro_rules! assert_closed { ($e:expr) => { match assert_err!($e) { broadcast::error::TryRecvError::Closed => {} - _ => panic!("did not lag"), + _ => panic!("is not closed"), } }; } @@ -517,3 +517,12 @@ fn resubscribe_lagged() { assert_empty!(rx); assert_empty!(rx_resub); } + +#[test] +fn resubscribe_to_closed_channel() { + let (tx, rx) = tokio::sync::broadcast::channel::(2); + drop(tx); + + let mut rx_resub = rx.resubscribe(); + assert_closed!(rx_resub.try_recv()); +}