Skip to content

Commit

Permalink
sync: remove broadcast channel slot level closed flag (#4867)
Browse files Browse the repository at this point in the history
The broadcast channel allows multiple senders to send messages to
multiple receivers, where each receiver receives messages starting from
when it subscribes. After all senders are dropped, the receivers will
continue to receive all waiting messages in the buffer and then receive
a `Closed` error.

To mark that a channel has closed, the implementation stores two closed
flags: one at the channel level and another in the buffer slot *after*
the last used slot (this may also be the earliest entry being kept for
lagged receivers, see #2425).

However, we don't need both closed flags; keeping the channel-level
closed flag is sufficient.

Without the slot level closed flag, each receiver receives each message
until it is up to date and for that receiver the channel is empty. Then,
the actual return message is chosen depending on the channel level
closed flag; if the channel is NOT closed, then `Empty` is returned, if
the channel is closed then `Closed` is returned instead.

With the modified logic, there is no longer a need to append a closed
token to the internal buffer (by setting the slot level closed flag on
the next slot). This fixes the off-by-one error described in #4814,
which caused a receiver which was created after the channel was already
closed to get `Empty` from `try_recv` (or hang forever when calling
`recv`) instead of receiving `Closed`.

As a bonus, we save a single `bool` on each buffer slot.

Refs: #4814
  • Loading branch information
hds committed Aug 10, 2022
1 parent 53cf021 commit 9d9488d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 76 deletions.
123 changes: 48 additions & 75 deletions tokio/src/sync/broadcast.rs
Expand Up @@ -336,9 +336,6 @@ struct Slot<T> {
/// Uniquely identifies the `send` stored in the slot.
pos: u64,

/// True signals the channel is closed.
closed: bool,

/// The value being broadcast.
///
/// The value is set by `send` when the write lock is held. When a reader
Expand Down Expand Up @@ -452,7 +449,6 @@ pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
buffer.push(RwLock::new(Slot {
rem: AtomicUsize::new(0),
pos: (i as u64).wrapping_sub(capacity as u64),
closed: false,
val: UnsafeCell::new(None),
}));
}
Expand Down Expand Up @@ -537,8 +533,43 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
self.send2(Some(value))
.map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap()))
let mut tail = self.shared.tail.lock();

if tail.rx_cnt == 0 {
return Err(SendError(value));
}

// Position to write into
let pos = tail.pos;
let rem = tail.rx_cnt;
let idx = (pos & self.shared.mask as u64) as usize;

// Update the tail position
tail.pos = tail.pos.wrapping_add(1);

// Get the slot
let mut slot = self.shared.buffer[idx].write().unwrap();

// Track the position
slot.pos = pos;

// Set remaining receivers
slot.rem.with_mut(|v| *v = rem);

// Write the value
slot.val = UnsafeCell::new(Some(value));

// Release the slot lock before notifying the receivers.
drop(slot);

tail.notify_rx();

// Release the mutex. This must happen after the slot lock is released,
// otherwise the writer lock bit could be cleared while another thread
// is in the critical section.
drop(tail);

Ok(rem)
}

/// Creates a new [`Receiver`] handle that will receive values sent **after**
Expand Down Expand Up @@ -610,49 +641,11 @@ impl<T> Sender<T> {
tail.rx_cnt
}

fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> {
fn close_channel(&self) {
let mut tail = self.shared.tail.lock();

if tail.rx_cnt == 0 {
return Err(SendError(value));
}

// Position to write into
let pos = tail.pos;
let rem = tail.rx_cnt;
let idx = (pos & self.shared.mask as u64) as usize;

// Update the tail position
tail.pos = tail.pos.wrapping_add(1);

// Get the slot
let mut slot = self.shared.buffer[idx].write().unwrap();

// Track the position
slot.pos = pos;

// Set remaining receivers
slot.rem.with_mut(|v| *v = rem);

// Set the closed bit if the value is `None`; otherwise write the value
if value.is_none() {
tail.closed = true;
slot.closed = true;
} else {
slot.val.with_mut(|ptr| unsafe { *ptr = value });
}

// Release the slot lock before notifying the receivers.
drop(slot);
tail.closed = true;

tail.notify_rx();

// Release the mutex. This must happen after the slot lock is released,
// otherwise the writer lock bit could be cleared while another thread
// is in the critical section.
drop(tail);

Ok(rem)
}
}

Expand Down Expand Up @@ -700,7 +693,7 @@ impl<T> Clone for Sender<T> {
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) {
let _ = self.send2(None);
self.close_channel();
}
}
}
Expand Down Expand Up @@ -784,14 +777,6 @@ impl<T> Receiver<T> {
let mut slot = self.shared.buffer[idx].read().unwrap();

if slot.pos != self.next {
let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);

// The receiver has read all current values in the channel and there
// is no waiter to register
if waiter.is_none() && next_pos == self.next {
return Err(TryRecvError::Empty);
}

// Release the `slot` lock before attempting to acquire the `tail`
// lock. This is required because `send` acquires the tail lock
// first followed by the slot lock. Acquiring the locks in reverse
Expand All @@ -813,6 +798,13 @@ impl<T> Receiver<T> {
let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);

if next_pos == self.next {
// At this point the channel is empty for *this* receiver. If
// it's been closed, then that's what we return, otherwise we
// set a waker and return empty.
if tail.closed {
return Err(TryRecvError::Closed);
}

// Store the waker
if let Some((waiter, waker)) = waiter {
// Safety: called while locked.
Expand Down Expand Up @@ -846,22 +838,7 @@ impl<T> Receiver<T> {
// catch up by skipping dropped messages and setting the
// internal cursor to the **oldest** message stored by the
// channel.
//
// However, finding the oldest position is a bit more
// complicated than `tail-position - buffer-size`. When
// the channel is closed, the tail position is incremented to
// signal a new `None` message, but `None` is not stored in the
// channel itself (see issue #2425 for why).
//
// To account for this, if the channel is closed, the tail
// position is decremented by `buffer-size + 1`.
let mut adjust = 0;
if tail.closed {
adjust = 1
}
let next = tail
.pos
.wrapping_sub(self.shared.buffer.len() as u64 + adjust);
let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);

let missed = next.wrapping_sub(self.next);

Expand All @@ -882,10 +859,6 @@ impl<T> Receiver<T> {

self.next = self.next.wrapping_add(1);

if slot.closed {
return Err(TryRecvError::Closed);
}

Ok(RecvGuard { slot })
}
}
Expand Down
11 changes: 10 additions & 1 deletion tokio/tests/sync_broadcast.rs
Expand Up @@ -47,7 +47,7 @@ macro_rules! assert_closed {
($e:expr) => {
match assert_err!($e) {
broadcast::error::TryRecvError::Closed => {}
_ => panic!("did not lag"),
_ => panic!("is not closed"),
}
};
}
Expand Down Expand Up @@ -517,3 +517,12 @@ fn resubscribe_lagged() {
assert_empty!(rx);
assert_empty!(rx_resub);
}

// Regression test for #4814: a receiver created by `resubscribe` after the
// channel has already been closed must get `Closed` from `try_recv`, not
// `Empty`.
#[test]
fn resubscribe_to_closed_channel() {
    let (sender, receiver) = tokio::sync::broadcast::channel::<u32>(2);

    // Dropping the only sender closes the channel.
    drop(sender);

    let mut late_receiver = receiver.resubscribe();
    assert_closed!(late_receiver.try_recv());
}

0 comments on commit 9d9488d

Please sign in to comment.