diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 0d9cd3bc176..c15f14eeba1 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -372,11 +372,43 @@ struct Recv<'a, T> { /// Entry in the waiter `LinkedList`. waiter: UnsafeCell, + + // Flag that indicates whether to output the number + // of sent messages that the `Receiver` behind this `Recv` + // has yet to receive. + output_kind: OutputQueueStatus, } unsafe impl<'a, T: Send> Send for Recv<'a, T> {} unsafe impl<'a, T: Send> Sync for Recv<'a, T> {} +#[derive(Copy, Clone)] +enum OutputQueueStatus { + Yes, + No, +} + +enum RecvOutput<'a, T> { + WithQueueStatus(RecvGuard<'a, T>, u64), + WithoutQueueStatus(RecvGuard<'a, T>), +} + +impl<'a, T: Clone> RecvOutput<'a, T> { + pub(crate) fn try_get_value(self) -> Option { + match self { + RecvOutput::WithoutQueueStatus(val) => val.clone_value(), + RecvOutput::WithQueueStatus(_, _) => panic!("expected RecvOutput::WithoutQueueStatus"), + } + } + + pub(crate) fn try_get_value_and_queue_status(self) -> (Option, u64) { + match self { + RecvOutput::WithoutQueueStatus(_) => panic!("expected RecvOutput::WithoutQueueStatus"), + RecvOutput::WithQueueStatus(val, status) => (val.clone_value(), status), + } + } +} + /// Max number of receivers. Reserve space to lock. const MAX_RECEIVERS: usize = usize::MAX >> 2; @@ -695,7 +727,8 @@ impl Receiver { fn recv_ref( &mut self, waiter: Option<(&UnsafeCell, &Waker)>, - ) -> Result, TryRecvError> { + queue_status: OutputQueueStatus, + ) -> Result, TryRecvError> { let idx = (self.next & self.shared.mask as u64) as usize; // The slot holding the next value to read @@ -773,10 +806,7 @@ impl Receiver { // // To account for this, if the channel is closed, the tail // position is decremented by `buffer-size + 1`. - let mut adjust = 0; - if tail.closed { - adjust = 1 - } + let adjust = if tail.closed { 1 } else { 0 }; let next = tail .pos .wrapping_sub(self.shared.buffer.len() as u64 + adjust); @@ -788,8 +818,12 @@ impl Receiver { // The receiver is slow but no values have been missed if missed == 0 { self.next = self.next.wrapping_add(1); + let behind = (self.shared.buffer.len() - 1) as u64; - return Ok(RecvGuard { slot }); + if let OutputQueueStatus::Yes = queue_status { + return Ok(RecvOutput::WithQueueStatus(RecvGuard { slot }, behind)); + } + return Ok(RecvOutput::WithoutQueueStatus(RecvGuard { slot })); } self.next = next; @@ -798,13 +832,62 @@ impl Receiver { } } - self.next = self.next.wrapping_add(1); - if slot.closed { return Err(TryRecvError::Closed); } - Ok(RecvGuard { slot }) + match queue_status { + OutputQueueStatus::Yes => { + // We need to acquire the tail lock here to get access to the next write + // position, but for that we need to drop the slot lock first, see comment + // above in this function (where `slot` is dropped) for an explaination for + // why this is necessary + drop(slot); + + let tail = self.shared.tail.lock(); + let next_write_pos = tail.pos; + drop(tail); + + slot = self.shared.buffer[idx].read().unwrap(); + + if slot.pos == self.next { + self.next = self.next.wrapping_add(1); + let behind = next_write_pos.wrapping_sub(self.next); + + Ok(RecvOutput::WithQueueStatus(RecvGuard { slot }, behind)) + } else { + // this is unlikely to happen, but if it does we have lagged behind the sender + // (this is because `self.next != slot.pos + buffer-size` must hold given that this + // condition wasn't fulfilled earlier in this method and `slot.pos` could only have + // increased). We proceed as above in the part of this method that handles missed + // messages. + + drop(slot); + let tail = self.shared.tail.lock(); + slot = self.shared.buffer[idx].read().unwrap(); + + let adjust = if tail.closed { 1 } else { 0 }; + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + drop(tail); + + let missed = next.wrapping_sub(self.next); + if missed == 0 { + self.next = self.next.wrapping_add(1); + let behind = self.shared.buffer.len() as u64; + return Ok(RecvOutput::WithQueueStatus(RecvGuard { slot }, behind)); + } + + self.next = next; + Err(TryRecvError::Lagged(missed)) + } + } + OutputQueueStatus::No => { + self.next = self.next.wrapping_add(1); + Ok(RecvOutput::WithoutQueueStatus(RecvGuard { slot })) + } + } } } @@ -883,7 +966,101 @@ impl Receiver { /// ``` pub async fn recv(&mut self) -> Result { let fut = Recv::new(self); - fut.await + match fut.await { + Ok(rcv_output) => match rcv_output { + OutputRecvPoll::WithoutQueueStatus(msg) => Ok(msg), + OutputRecvPoll::WithQueueStatus(_, _) => { + panic!("Cannot receive OutputRecvPoll::WithQueueStatus here") + } + }, + Err(e) => Err(e), + } + } + + /// Receives the next value for this receiver and the number of messages + /// that were sent by a sender and have not yet been received by this + /// receiver. + /// + /// Each [`Receiver`] handle will receive a clone of all values sent + /// **after** it has subscribed. + /// + /// `Err(RecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to + /// [`recv_with_queue_status`] will return with `Err(RecvError::Lagged)` and + /// the [`Receiver`]'s internal cursor is updated to point to the oldest value + /// still held by the channel. A subsequent call to [`recv_with_queue_status`] + /// will return this value **unless** it has been since overwritten. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_with_queue_status` is used as the + /// event in a [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`recv_with_queue_status`]: crate::sync::broadcast::Receiver::recv_with_queue_status + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// use tokio::time::sleep; + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// + /// let handle = tokio::spawn(async move { + /// sleep(Duration::from_millis(200)).await; + /// assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (10, 1)); + /// assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (20, 0)); + /// }); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// + /// handle.await.unwrap(); + /// } + /// ``` + /// + /// Handling lag + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(2); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// tx.send(30).unwrap(); + /// + /// // The receiver lagged behind + /// assert!(rx.recv().await.is_err()); + /// + /// // At this point, we can abort or continue with lost messages + /// + /// assert_eq!(20, rx.recv().await.unwrap()); + /// assert_eq!(30, rx.recv().await.unwrap()); + /// } + /// ``` + pub async fn recv_with_queue_status(&mut self) -> Result<(T, u64), RecvError> { + let fut = Recv::new_with_queue_status(self); + match fut.await { + Ok(rcv_output) => match rcv_output { + OutputRecvPoll::WithQueueStatus(msg, queue_status) => Ok((msg, queue_status)), + OutputRecvPoll::WithoutQueueStatus(_) => { + panic!("Cannot receive OutputRecvPoll::WithoutQueueStatus here") + } + }, + Err(e) => Err(e), + } } /// Attempts to return a pending value on this receiver without awaiting. @@ -927,8 +1104,9 @@ impl Receiver { /// } /// ``` pub fn try_recv(&mut self) -> Result { - let guard = self.recv_ref(None)?; - guard.clone_value().ok_or(TryRecvError::Closed) + let guard = self.recv_ref(None, OutputQueueStatus::No)?.try_get_value(); + + guard.ok_or(TryRecvError::Closed) } } @@ -942,7 +1120,7 @@ impl Drop for Receiver { drop(tail); while self.next < until { - match self.recv_ref(None) { + match self.recv_ref(None, OutputQueueStatus::No) { Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -965,6 +1143,20 @@ impl<'a, T> Recv<'a, T> { pointers: linked_list::Pointers::new(), _p: PhantomPinned, }), + output_kind: OutputQueueStatus::No, + } + } + + fn new_with_queue_status(receiver: &'a mut Receiver) -> Recv<'a, T> { + Recv { + receiver, + waiter: UnsafeCell::new(Waiter { + queued: false, + waker: None, + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + }), + output_kind: OutputQueueStatus::Yes, } } @@ -981,23 +1173,49 @@ impl<'a, T> Recv<'a, T> { } } +enum OutputRecvPoll { + WithQueueStatus(T, u64), + WithoutQueueStatus(T), +} + impl<'a, T> Future for Recv<'a, T> where T: Clone, { - type Output = Result; + type Output = Result, RecvError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, RecvError>> { + let output_queue_status = self.output_kind; let (receiver, waiter) = self.project(); - let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { - Ok(value) => value, + match receiver.recv_ref(Some((waiter, cx.waker())), output_queue_status) { + Ok(value) => match output_queue_status { + OutputQueueStatus::Yes => { + let (out_opt, queue_status) = value.try_get_value_and_queue_status(); + match out_opt { + Some(out) => { + return Poll::Ready(Ok(OutputRecvPoll::WithQueueStatus( + out, + queue_status, + ))); + } + None => return Poll::Ready(Err(RecvError::Closed)), + } + } + OutputQueueStatus::No => match value.try_get_value() { + Some(out) => { + return Poll::Ready(Ok(OutputRecvPoll::WithoutQueueStatus(out))); + } + None => return Poll::Ready(Err(RecvError::Closed)), + }, + }, Err(TryRecvError::Empty) => return Poll::Pending, Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), - }; - - Poll::Ready(guard.clone_value().ok_or(RecvError::Closed)) + } } } diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index 1b68eb7edbd..7d07b8dfa0d 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -5,13 +5,16 @@ #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; +use tokio::runtime; use tokio::sync::broadcast; +use tokio::time; use tokio_test::task; use tokio_test::{ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, }; use std::sync::Arc; +use std::time::Duration; macro_rules! assert_recv { ($e:expr) => { @@ -460,3 +463,28 @@ fn lagging_receiver_recovers_after_wrap_open() { fn is_closed(err: broadcast::error::RecvError) -> bool { matches!(err, broadcast::error::RecvError::Closed) } + +#[test] +fn recv_with_queue_status() { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .start_paused(true) + .build() + .unwrap(); + + let (tx, mut rx1) = broadcast::channel(16); + + rt.block_on(async { + let handle1 = rt.spawn(async move { + time::sleep(Duration::from_millis(1)).await; + assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (10, 1)); + assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (20, 0)); + }); + + tx.send(10).unwrap(); + tx.send(20).unwrap(); + + time::resume(); + handle1.await.unwrap(); + }); +}