From c390a62387fe7346951c8bc57ea2761614b83e82 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Thu, 12 Jan 2023 13:53:31 -0500 Subject: [PATCH] Add broadcast::Sender::len (#5343) * Add broadcast::Sender::len * Add a randomized test for broadcast::Sender::len * fix wasm build * less silly cfg * review feedback * grammar? --- tokio/src/sync/broadcast.rs | 95 ++++++++++++++++++++++++++++++++++- tokio/tests/sync_broadcast.rs | 60 ++++++++++++++++++++++ 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index ede990b046e..1c6b2caa3bb 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -603,6 +603,97 @@ impl Sender { new_receiver(shared) } + /// Returns the number of queued values. + /// + /// A value is queued until it has either been seen by all receivers that were alive at the time + /// it was sent, or has been evicted from the queue by subsequent sends that exceeded the + /// queue's capacity. + /// + /// # Note + /// + /// In contrast to [`Receiver::len`], this method only reports queued values and not values that + /// have been evicted from the queue before being seen by all receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// tx.send(30).unwrap(); + /// + /// assert_eq!(tx.len(), 3); + /// + /// rx1.recv().await.unwrap(); + /// + /// // The len is still 3 since rx2 hasn't seen the first value yet. + /// assert_eq!(tx.len(), 3); + /// + /// rx2.recv().await.unwrap(); + /// + /// assert_eq!(tx.len(), 2); + /// } + /// ``` + pub fn len(&self) -> usize { + let tail = self.shared.tail.lock(); + + let base_idx = (tail.pos & self.shared.mask as u64) as usize; + let mut low = 0; + let mut high = self.shared.buffer.len(); + while low < high { + let mid = low + (high - low) / 2; + let idx = base_idx.wrapping_add(mid) & self.shared.mask; + if self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 { + low = mid + 1; + } else { + high = mid; + } + } + + self.shared.buffer.len() - low + } + + /// Returns true if there are no queued values. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// let mut rx2 = tx.subscribe(); + /// + /// assert!(tx.is_empty()); + /// + /// tx.send(10).unwrap(); + /// + /// assert!(!tx.is_empty()); + /// + /// rx1.recv().await.unwrap(); + /// + /// // The queue is still not empty since rx2 hasn't seen the value. + /// assert!(!tx.is_empty()); + /// + /// rx2.recv().await.unwrap(); + /// + /// assert!(tx.is_empty()); + /// } + /// ``` + pub fn is_empty(&self) -> bool { + let tail = self.shared.tail.lock(); + + let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize; + self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 + } + /// Returns the number of active receivers /// /// An active receiver is a [`Receiver`] handle returned from [`channel`] or @@ -731,7 +822,7 @@ impl Receiver { /// assert_eq!(rx1.len(), 2); /// assert_eq!(rx1.recv().await.unwrap(), 10); /// assert_eq!(rx1.len(), 1); - /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// assert_eq!(rx1.recv().await.unwrap(), 20); /// assert_eq!(rx1.len(), 0); /// } /// ``` @@ -761,7 +852,7 @@ impl Receiver { /// /// assert!(!rx1.is_empty()); /// assert_eq!(rx1.recv().await.unwrap(), 10); - /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// assert_eq!(rx1.recv().await.unwrap(), 20); /// assert!(rx1.is_empty()); /// } /// ``` diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index 9aa34841e26..67c378b84a6 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -526,3 +526,63 @@ fn resubscribe_to_closed_channel() { let mut rx_resub = rx.resubscribe(); assert_closed!(rx_resub.try_recv()); } + +#[test] +fn sender_len() { + let (tx, mut rx1) = broadcast::channel(4); + let mut rx2 = tx.subscribe(); + + assert_eq!(tx.len(), 0); + assert!(tx.is_empty()); + + tx.send(1).unwrap(); + tx.send(2).unwrap(); + tx.send(3).unwrap(); + + assert_eq!(tx.len(), 3); + assert!(!tx.is_empty()); + + assert_recv!(rx1); + assert_recv!(rx1); + + assert_eq!(tx.len(), 3); + assert!(!tx.is_empty()); + + assert_recv!(rx2); + + assert_eq!(tx.len(), 2); + assert!(!tx.is_empty()); + + tx.send(4).unwrap(); + tx.send(5).unwrap(); + tx.send(6).unwrap(); + + assert_eq!(tx.len(), 4); + assert!(!tx.is_empty()); +} + +#[test] +#[cfg(not(tokio_wasm_not_wasi))] +fn sender_len_random() { + use rand::Rng; + + let (tx, mut rx1) = broadcast::channel(16); + let mut rx2 = tx.subscribe(); + + for _ in 0..1000 { + match rand::thread_rng().gen_range(0..4) { + 0 => { + let _ = rx1.try_recv(); + } + 1 => { + let _ = rx2.try_recv(); + } + _ => { + tx.send(0).unwrap(); + } + } + + let expected_len = usize::min(usize::max(rx1.len(), rx2.len()), 16); + assert_eq!(tx.len(), expected_len); + } +}