Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: remove broadcast channel slot level closed flag #4867

Merged
merged 2 commits into from Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
123 changes: 48 additions & 75 deletions tokio/src/sync/broadcast.rs
Expand Up @@ -336,9 +336,6 @@ struct Slot<T> {
/// Uniquely identifies the `send` stored in the slot.
pos: u64,

/// True signals the channel is closed.
closed: bool,

/// The value being broadcast.
///
/// The value is set by `send` when the write lock is held. When a reader
Expand Down Expand Up @@ -452,7 +449,6 @@ pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
buffer.push(RwLock::new(Slot {
rem: AtomicUsize::new(0),
pos: (i as u64).wrapping_sub(capacity as u64),
closed: false,
val: UnsafeCell::new(None),
}));
}
Expand Down Expand Up @@ -537,8 +533,43 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
self.send2(Some(value))
.map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap()))
let mut tail = self.shared.tail.lock();

if tail.rx_cnt == 0 {
return Err(SendError(value));
}

// Position to write into
let pos = tail.pos;
let rem = tail.rx_cnt;
let idx = (pos & self.shared.mask as u64) as usize;

// Update the tail position
tail.pos = tail.pos.wrapping_add(1);

// Get the slot
let mut slot = self.shared.buffer[idx].write().unwrap();

// Track the position
slot.pos = pos;

// Set remaining receivers
slot.rem.with_mut(|v| *v = rem);

// Write the value
slot.val.with_mut(|ptr| unsafe { *ptr = Some(value) });
hds marked this conversation as resolved.
Show resolved Hide resolved

// Release the slot lock before notifying the receivers.
drop(slot);

tail.notify_rx();

// Release the mutex. This must happen after the slot lock is released,
// otherwise the writer lock bit could be cleared while another thread
// is in the critical section.
drop(tail);

Ok(rem)
}

/// Creates a new [`Receiver`] handle that will receive values sent **after**
Expand Down Expand Up @@ -610,49 +641,11 @@ impl<T> Sender<T> {
tail.rx_cnt
}

fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> {
fn close_channel(&self) {
let mut tail = self.shared.tail.lock();

if tail.rx_cnt == 0 {
return Err(SendError(value));
}

// Position to write into
let pos = tail.pos;
let rem = tail.rx_cnt;
let idx = (pos & self.shared.mask as u64) as usize;

// Update the tail position
tail.pos = tail.pos.wrapping_add(1);

// Get the slot
let mut slot = self.shared.buffer[idx].write().unwrap();

// Track the position
slot.pos = pos;

// Set remaining receivers
slot.rem.with_mut(|v| *v = rem);

// Set the closed bit if the value is `None`; otherwise write the value
if value.is_none() {
tail.closed = true;
slot.closed = true;
} else {
slot.val.with_mut(|ptr| unsafe { *ptr = value });
}

// Release the slot lock before notifying the receivers.
drop(slot);
tail.closed = true;

tail.notify_rx();

// Release the mutex. This must happen after the slot lock is released,
// otherwise the writer lock bit could be cleared while another thread
// is in the critical section.
drop(tail);

Ok(rem)
}
}

Expand Down Expand Up @@ -700,7 +693,7 @@ impl<T> Clone for Sender<T> {
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) {
let _ = self.send2(None);
self.close_channel();
}
}
}
Expand Down Expand Up @@ -784,14 +777,6 @@ impl<T> Receiver<T> {
let mut slot = self.shared.buffer[idx].read().unwrap();

if slot.pos != self.next {
let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);

// The receiver has read all current values in the channel and there
// is no waiter to register
if waiter.is_none() && next_pos == self.next {
return Err(TryRecvError::Empty);
}

// Release the `slot` lock before attempting to acquire the `tail`
// lock. This is required because `send2` acquires the tail lock
// first followed by the slot lock. Acquiring the locks in reverse
Expand All @@ -813,6 +798,13 @@ impl<T> Receiver<T> {
let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);

if next_pos == self.next {
// At this point the channel is empty for *this* receiver. If
// it's been closed, then that's what we return, otherwise we
// set a waker and return empty.
if tail.closed {
return Err(TryRecvError::Closed);
}

// Store the waker
if let Some((waiter, waker)) = waiter {
// Safety: called while locked.
Expand Down Expand Up @@ -846,22 +838,7 @@ impl<T> Receiver<T> {
// catch up by skipping dropped messages and setting the
// internal cursor to the **oldest** message stored by the
// channel.
//
// However, finding the oldest position is a bit more
// complicated than `tail-position - buffer-size`. When
// the channel is closed, the tail position is incremented to
// signal a new `None` message, but `None` is not stored in the
// channel itself (see issue #2425 for why).
//
// To account for this, if the channel is closed, the tail
// position is decremented by `buffer-size + 1`.
let mut adjust = 0;
if tail.closed {
adjust = 1
}
let next = tail
.pos
.wrapping_sub(self.shared.buffer.len() as u64 + adjust);
let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);

let missed = next.wrapping_sub(self.next);

Expand All @@ -882,10 +859,6 @@ impl<T> Receiver<T> {

self.next = self.next.wrapping_add(1);

if slot.closed {
return Err(TryRecvError::Closed);
}

Ok(RecvGuard { slot })
}
}
Expand Down
11 changes: 10 additions & 1 deletion tokio/tests/sync_broadcast.rs
Expand Up @@ -47,7 +47,7 @@ macro_rules! assert_closed {
($e:expr) => {
match assert_err!($e) {
broadcast::error::TryRecvError::Closed => {}
_ => panic!("did not lag"),
_ => panic!("is not closed"),
}
};
}
Expand Down Expand Up @@ -517,3 +517,12 @@ fn resubscribe_lagged() {
assert_empty!(rx);
assert_empty!(rx_resub);
}

#[test]
fn resubscribe_to_closed_channel() {
let (tx, rx) = tokio::sync::broadcast::channel::<u32>(2);
drop(tx);

let mut rx_resub = rx.resubscribe();
assert_closed!(rx_resub.try_recv());
}