add recv_with_queue_status
b-naber committed Feb 25, 2022
1 parent e8f19e7 commit 0518874
Showing 2 changed files with 266 additions and 20 deletions.
258 changes: 238 additions & 20 deletions tokio/src/sync/broadcast.rs
@@ -372,11 +372,43 @@ struct Recv<'a, T> {

/// Entry in the waiter `LinkedList`.
waiter: UnsafeCell<Waiter>,

// Flag that indicates whether to output the number
// of sent messages that the `Receiver` behind this `Recv`
// has yet to receive.
output_kind: OutputQueueStatus,
}

unsafe impl<'a, T: Send> Send for Recv<'a, T> {}
unsafe impl<'a, T: Send> Sync for Recv<'a, T> {}

#[derive(Copy, Clone)]
enum OutputQueueStatus {
Yes,
No,
}

enum RecvOutput<'a, T> {
WithQueueStatus(RecvGuard<'a, T>, u64),
WithoutQueueStatus(RecvGuard<'a, T>),
}

impl<'a, T: Clone> RecvOutput<'a, T> {
pub(crate) fn try_get_value(self) -> Option<T> {
match self {
RecvOutput::WithoutQueueStatus(val) => val.clone_value(),
RecvOutput::WithQueueStatus(_, _) => panic!("expected RecvOutput::WithoutQueueStatus"),
}
}

pub(crate) fn try_get_value_and_queue_status(self) -> (Option<T>, u64) {
match self {
RecvOutput::WithoutQueueStatus(_) => panic!("expected RecvOutput::WithQueueStatus"),
RecvOutput::WithQueueStatus(val, status) => (val.clone_value(), status),
}
}
}

/// Max number of receivers. Reserve space to lock.
const MAX_RECEIVERS: usize = usize::MAX >> 2;

@@ -695,7 +727,8 @@ impl<T> Receiver<T> {
fn recv_ref(
&mut self,
waiter: Option<(&UnsafeCell<Waiter>, &Waker)>,
queue_status: OutputQueueStatus,
) -> Result<RecvOutput<'_, T>, TryRecvError> {
let idx = (self.next & self.shared.mask as u64) as usize;

// The slot holding the next value to read
@@ -773,10 +806,7 @@ impl<T> Receiver<T> {
//
// To account for this, if the channel is closed, the tail
// position is decremented by `buffer-size + 1`.
let adjust = if tail.closed { 1 } else { 0 };
let next = tail
.pos
.wrapping_sub(self.shared.buffer.len() as u64 + adjust);
@@ -788,8 +818,12 @@ impl<T> Receiver<T> {
// The receiver is slow but no values have been missed
if missed == 0 {
self.next = self.next.wrapping_add(1);
let behind = (self.shared.buffer.len() - 1) as u64;

if let OutputQueueStatus::Yes = queue_status {
return Ok(RecvOutput::WithQueueStatus(RecvGuard { slot }, behind));
}
return Ok(RecvOutput::WithoutQueueStatus(RecvGuard { slot }));
}

self.next = next;
@@ -798,13 +832,62 @@ impl<T> Receiver<T> {
}
}

if slot.closed {
return Err(TryRecvError::Closed);
}

match queue_status {
OutputQueueStatus::Yes => {
// We need to acquire the tail lock here to get access to the next write
// position, but for that we need to drop the slot lock first. See the comment
// above in this function (where `slot` is dropped) for an explanation of why
// this is necessary.
drop(slot);

let tail = self.shared.tail.lock();
let next_write_pos = tail.pos;
drop(tail);

slot = self.shared.buffer[idx].read().unwrap();

if slot.pos == self.next {
self.next = self.next.wrapping_add(1);
let behind = next_write_pos.wrapping_sub(self.next);

Ok(RecvOutput::WithQueueStatus(RecvGuard { slot }, behind))
} else {
// This is unlikely to happen, but if it does we have lagged behind the sender
// (`self.next != slot.pos + buffer-size` must hold, given that this condition
// wasn't fulfilled earlier in this method and `slot.pos` can only have
// increased). We proceed as in the part of this method above that handles
// missed messages.

drop(slot);
let tail = self.shared.tail.lock();
slot = self.shared.buffer[idx].read().unwrap();

let adjust = if tail.closed { 1 } else { 0 };
let next = tail
.pos
.wrapping_sub(self.shared.buffer.len() as u64 + adjust);
drop(tail);

let missed = next.wrapping_sub(self.next);
if missed == 0 {
self.next = self.next.wrapping_add(1);
let behind = self.shared.buffer.len() as u64;
return Ok(RecvOutput::WithQueueStatus(RecvGuard { slot }, behind));
}

self.next = next;
Err(TryRecvError::Lagged(missed))
}
}
OutputQueueStatus::No => {
self.next = self.next.wrapping_add(1);
Ok(RecvOutput::WithoutQueueStatus(RecvGuard { slot }))
}
}
}
}
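
The lag and queue-status counters in `recv_ref` are plain wrapping arithmetic on `u64` positions. Below is a standalone sketch with made-up positions (not tokio internals) that mirrors the two formulas used above: `missed` after clamping to the oldest retained slot, and `behind` in the `OutputQueueStatus::Yes` branch.

```rust
fn main() {
    let buffer_len: u64 = 16;

    // Illustrative values: the sender has written 40 messages and this
    // receiver's cursor is still at 20.
    let tail_pos: u64 = 40; // next write position (`tail.pos`)
    let receiver_next: u64 = 20; // receiver cursor (`self.next`)

    // Oldest position still held by the channel (channel open, so adjust = 0),
    // mirroring `tail.pos.wrapping_sub(self.shared.buffer.len() as u64 + adjust)`.
    let oldest = tail_pos.wrapping_sub(buffer_len); // 24
    let missed = oldest.wrapping_sub(receiver_next); // 4 values were overwritten
    assert_eq!(missed, 4);

    // Once the cursor is clamped to `oldest` and that value is consumed, the
    // queue-status count is `next_write_pos - self.next`, mirroring the
    // `behind` computation in the `OutputQueueStatus::Yes` branch.
    let after_consume = oldest.wrapping_add(1);
    let behind = tail_pos.wrapping_sub(after_consume);
    assert_eq!(behind, buffer_len - 1); // 15
}
```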

@@ -883,7 +966,101 @@ impl<T: Clone> Receiver<T> {
/// ```
pub async fn recv(&mut self) -> Result<T, RecvError> {
let fut = Recv::new(self);
match fut.await {
Ok(rcv_output) => match rcv_output {
OutputRecvPoll::WithoutQueueStatus(msg) => Ok(msg),
OutputRecvPoll::WithQueueStatus(_, _) => {
panic!("Cannot receive OutputRecvPoll::WithQueueStatus here")
}
},
Err(e) => Err(e),
}
}

/// Receives the next value for this receiver, together with the number of
/// messages that have been sent on the channel but not yet received by this
/// receiver.
///
/// Each [`Receiver`] handle will receive a clone of all values sent
/// **after** it has subscribed.
///
/// `Err(RecvError::Closed)` is returned when all `Sender` halves have
/// dropped, indicating that no further values can be sent on the channel.
///
/// If the [`Receiver`] handle falls behind, once the channel is full, newly
/// sent values will overwrite old values. At this point, a call to
/// [`recv_with_queue_status`] will return with `Err(RecvError::Lagged)` and
/// the [`Receiver`]'s internal cursor is updated to point to the oldest value
/// still held by the channel. A subsequent call to [`recv_with_queue_status`]
/// will return this value **unless** it has been since overwritten.
///
/// # Cancel safety
///
/// This method is cancel safe. If `recv_with_queue_status` is used as the
/// event in a [`tokio::select!`](crate::select) statement and some other branch
/// completes first, it is guaranteed that no messages were received on this
/// channel.
///
/// [`Receiver`]: crate::sync::broadcast::Receiver
/// [`recv_with_queue_status`]: crate::sync::broadcast::Receiver::recv_with_queue_status
///
/// # Examples
///
/// ```
/// use tokio::sync::broadcast;
/// use tokio::time::sleep;
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx1) = broadcast::channel(16);
///
/// let handle = tokio::spawn(async move {
/// sleep(Duration::from_millis(200)).await;
/// assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (10, 1));
/// assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (20, 0));
/// });
///
/// tx.send(10).unwrap();
/// tx.send(20).unwrap();
///
/// handle.await.unwrap();
/// }
/// ```
///
/// Handling lag
///
/// ```
/// use tokio::sync::broadcast;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = broadcast::channel(2);
///
/// tx.send(10).unwrap();
/// tx.send(20).unwrap();
/// tx.send(30).unwrap();
///
/// // The receiver lagged behind
/// assert!(rx.recv().await.is_err());
///
/// // At this point, we can abort or continue with lost messages
///
/// assert_eq!(20, rx.recv().await.unwrap());
/// assert_eq!(30, rx.recv().await.unwrap());
/// }
/// ```
pub async fn recv_with_queue_status(&mut self) -> Result<(T, u64), RecvError> {
let fut = Recv::new_with_queue_status(self);
match fut.await {
Ok(rcv_output) => match rcv_output {
OutputRecvPoll::WithQueueStatus(msg, queue_status) => Ok((msg, queue_status)),
OutputRecvPoll::WithoutQueueStatus(_) => {
panic!("Cannot receive OutputRecvPoll::WithoutQueueStatus here")
}
},
Err(e) => Err(e),
}
}
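
The cancel-safety note in the docs above can be illustrated with a small sketch that uses `recv_with_queue_status` as one branch of a `tokio::select!`. The channel setup and timeout are made up for the example and assume the API added in this diff; if the sleep branch wins, no message has been consumed.

```rust
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::time::sleep;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<u32>(16);
    tx.send(1).unwrap();

    tokio::select! {
        res = rx.recv_with_queue_status() => {
            // `pending` is the number of sent messages this receiver has not yet seen.
            let (msg, pending) = res.unwrap();
            println!("received {} with {} messages still queued", msg, pending);
        }
        _ = sleep(Duration::from_millis(10)) => {
            // The message is still queued; a later call to
            // `recv_with_queue_status` will return it.
            println!("timed out without consuming a message");
        }
    }
}
```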

/// Attempts to return a pending value on this receiver without awaiting.
@@ -927,8 +1104,9 @@ impl<T: Clone> Receiver<T> {
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
let guard = self.recv_ref(None, OutputQueueStatus::No)?.try_get_value();

guard.ok_or(TryRecvError::Closed)
}
}

@@ -942,7 +1120,7 @@ impl<T> Drop for Receiver<T> {
drop(tail);

while self.next < until {
match self.recv_ref(None, OutputQueueStatus::No) {
Ok(_) => {}
// The channel is closed
Err(TryRecvError::Closed) => break,
@@ -965,6 +1143,20 @@ impl<'a, T> Recv<'a, T> {
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
}),
output_kind: OutputQueueStatus::No,
}
}

fn new_with_queue_status(receiver: &'a mut Receiver<T>) -> Recv<'a, T> {
Recv {
receiver,
waiter: UnsafeCell::new(Waiter {
queued: false,
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
}),
output_kind: OutputQueueStatus::Yes,
}
}

@@ -981,23 +1173,49 @@ impl<'a, T> Recv<'a, T> {
}
}

enum OutputRecvPoll<T> {
WithQueueStatus(T, u64),
WithoutQueueStatus(T),
}

impl<'a, T> Future for Recv<'a, T>
where
T: Clone,
{
type Output = Result<OutputRecvPoll<T>, RecvError>;

fn poll(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<OutputRecvPoll<T>, RecvError>> {
let output_queue_status = self.output_kind;
let (receiver, waiter) = self.project();

match receiver.recv_ref(Some((waiter, cx.waker())), output_queue_status) {
Ok(value) => match output_queue_status {
OutputQueueStatus::Yes => {
let (out_opt, queue_status) = value.try_get_value_and_queue_status();
match out_opt {
Some(out) => {
return Poll::Ready(Ok(OutputRecvPoll::WithQueueStatus(
out,
queue_status,
)));
}
None => return Poll::Ready(Err(RecvError::Closed)),
}
}
OutputQueueStatus::No => match value.try_get_value() {
Some(out) => {
return Poll::Ready(Ok(OutputRecvPoll::WithoutQueueStatus(out)));
}
None => return Poll::Ready(Err(RecvError::Closed)),
},
},
Err(TryRecvError::Empty) => return Poll::Pending,
Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))),
Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)),
}
}
}

28 changes: 28 additions & 0 deletions tokio/tests/sync_broadcast.rs
@@ -5,13 +5,16 @@
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test as test;

use tokio::runtime;
use tokio::sync::broadcast;
use tokio::time;
use tokio_test::task;
use tokio_test::{
assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
};

use std::sync::Arc;
use std::time::Duration;

macro_rules! assert_recv {
($e:expr) => {
@@ -460,3 +463,28 @@ fn lagging_receiver_recovers_after_wrap_open() {
fn is_closed(err: broadcast::error::RecvError) -> bool {
matches!(err, broadcast::error::RecvError::Closed)
}

#[test]
fn recv_with_queue_status() {
let rt = runtime::Builder::new_current_thread()
.enable_time()
.start_paused(true)
.build()
.unwrap();

let (tx, mut rx1) = broadcast::channel(16);

rt.block_on(async {
let handle1 = rt.spawn(async move {
time::sleep(Duration::from_millis(1)).await;
assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (10, 1));
assert_eq!(rx1.recv_with_queue_status().await.unwrap(), (20, 0));
});

tx.send(10).unwrap();
tx.send(20).unwrap();

time::resume();
handle1.await.unwrap();
});
}
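
Based on the `behind` computation in `recv_ref` (next write position minus the receiver's cursor) and consistent with the doc-comment example, an additional test along the following lines would be expected to hold. It is a sketch only, not part of this commit, and assumes the same imports as the test file above (`tokio::runtime`, `tokio::sync::broadcast`).

```rust
#[test]
fn recv_with_queue_status_counts_pending_messages() {
    // Sketch: three values are queued before the receiver polls, so the
    // expected "not yet received" counts are 2, 1 and 0.
    let rt = runtime::Builder::new_current_thread().build().unwrap();

    let (tx, mut rx) = broadcast::channel(16);

    tx.send(10).unwrap();
    tx.send(20).unwrap();
    tx.send(30).unwrap();

    rt.block_on(async {
        assert_eq!(rx.recv_with_queue_status().await.unwrap(), (10, 2));
        assert_eq!(rx.recv_with_queue_status().await.unwrap(), (20, 1));
        assert_eq!(rx.recv_with_queue_status().await.unwrap(), (30, 0));
    });
}
```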
