From f6c0405084cc7efbbf10202d44aaf8885fd3f84f Mon Sep 17 00:00:00 2001 From: estk Date: Sat, 28 May 2022 13:15:27 -0700 Subject: [PATCH] sync: add resubscribe method to broadcast::Receiver (#4607) --- tokio/src/sync/broadcast.rs | 28 ++++++++++++++++++++++++++ tokio/tests/sync_broadcast.rs | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 846d6c027a6..c796d129920 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -647,6 +647,7 @@ impl Sender { } } +/// Create a new `Receiver` which reads starting from the tail. fn new_receiver(shared: Arc>) -> Receiver { let mut tail = shared.tail.lock(); @@ -881,6 +882,33 @@ impl Receiver { } impl Receiver { + /// Re-subscribes to the channel starting from the current tail element. + /// + /// This [`Receiver`] handle will receive a clone of all values sent + /// **after** it has resubscribed. This will not include elements that are + /// in the queue of the current receiver. Consider the following example. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(2); + /// + /// tx.send(1).unwrap(); + /// let mut rx2 = rx.resubscribe(); + /// tx.send(2).unwrap(); + /// + /// assert_eq!(rx2.recv().await.unwrap(), 2); + /// assert_eq!(rx.recv().await.unwrap(), 1); + /// } + /// ``` + pub fn resubscribe(&self) -> Self { + let shared = self.shared.clone(); + new_receiver(shared) + } /// Receives the next value for this receiver. /// /// Each [`Receiver`] handle will receive a clone of all values sent diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index ca8b4d7f4ce..b38b6380023 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -479,3 +479,41 @@ fn receiver_len_with_lagged() { fn is_closed(err: broadcast::error::RecvError) -> bool { matches!(err, broadcast::error::RecvError::Closed) } + +#[test] +fn resubscribe_points_to_tail() { + let (tx, mut rx) = broadcast::channel(3); + tx.send(1).unwrap(); + + let mut rx_resub = rx.resubscribe(); + + // verify we're one behind at the start + assert_empty!(rx_resub); + assert_eq!(assert_recv!(rx), 1); + + // verify we do not affect rx + tx.send(2).unwrap(); + assert_eq!(assert_recv!(rx_resub), 2); + tx.send(3).unwrap(); + assert_eq!(assert_recv!(rx), 2); + assert_eq!(assert_recv!(rx), 3); + assert_empty!(rx); + + assert_eq!(assert_recv!(rx_resub), 3); + assert_empty!(rx_resub); +} + +#[test] +fn resubscribe_lagged() { + let (tx, mut rx) = broadcast::channel(1); + tx.send(1).unwrap(); + tx.send(2).unwrap(); + + let mut rx_resub = rx.resubscribe(); + assert_lagged!(rx.try_recv(), 1); + assert_empty!(rx_resub); + + assert_eq!(assert_recv!(rx), 2); + assert_empty!(rx); + assert_empty!(rx_resub); +}