diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 69ede8f89ce..426011e79c3 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -336,9 +336,6 @@ struct Slot { /// Uniquely identifies the `send` stored in the slot. pos: u64, - /// True signals the channel is closed. - closed: bool, - /// The value being broadcast. /// /// The value is set by `send` when the write lock is held. When a reader @@ -452,7 +449,6 @@ pub fn channel(mut capacity: usize) -> (Sender, Receiver) { buffer.push(RwLock::new(Slot { rem: AtomicUsize::new(0), pos: (i as u64).wrapping_sub(capacity as u64), - closed: false, val: UnsafeCell::new(None), })); } @@ -537,8 +533,43 @@ impl Sender { /// } /// ``` pub fn send(&self, value: T) -> Result> { - self.send2(Some(value)) - .map_err(|SendError(maybe_v)| SendError(maybe_v.unwrap())) + let mut tail = self.shared.tail.lock(); + + if tail.rx_cnt == 0 { + return Err(SendError(value)); + } + + // Position to write into + let pos = tail.pos; + let rem = tail.rx_cnt; + let idx = (pos & self.shared.mask as u64) as usize; + + // Update the tail position + tail.pos = tail.pos.wrapping_add(1); + + // Get the slot + let mut slot = self.shared.buffer[idx].write().unwrap(); + + // Track the position + slot.pos = pos; + + // Set remaining receivers + slot.rem.with_mut(|v| *v = rem); + + // Write the value + slot.val = UnsafeCell::new(Some(value)); + + // Release the slot lock before notifying the receivers. + drop(slot); + + tail.notify_rx(); + + // Release the mutex. This must happen after the slot lock is released, + // otherwise the writer lock bit could be cleared while another thread + // is in the critical section. + drop(tail); + + Ok(rem) } /// Creates a new [`Receiver`] handle that will receive values sent **after** @@ -610,49 +641,11 @@ impl Sender { tail.rx_cnt } - fn send2(&self, value: Option) -> Result>> { + fn close_channel(&self) { let mut tail = self.shared.tail.lock(); - - if tail.rx_cnt == 0 { - return Err(SendError(value)); - } - - // Position to write into - let pos = tail.pos; - let rem = tail.rx_cnt; - let idx = (pos & self.shared.mask as u64) as usize; - - // Update the tail position - tail.pos = tail.pos.wrapping_add(1); - - // Get the slot - let mut slot = self.shared.buffer[idx].write().unwrap(); - - // Track the position - slot.pos = pos; - - // Set remaining receivers - slot.rem.with_mut(|v| *v = rem); - - // Set the closed bit if the value is `None`; otherwise write the value - if value.is_none() { - tail.closed = true; - slot.closed = true; - } else { - slot.val.with_mut(|ptr| unsafe { *ptr = value }); - } - - // Release the slot lock before notifying the receivers. - drop(slot); + tail.closed = true; tail.notify_rx(); - - // Release the mutex. This must happen after the slot lock is released, - // otherwise the writer lock bit could be cleared while another thread - // is in the critical section. - drop(tail); - - Ok(rem) } } @@ -700,7 +693,7 @@ impl Clone for Sender { impl Drop for Sender { fn drop(&mut self) { if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) { - let _ = self.send2(None); + self.close_channel(); } } } @@ -784,14 +777,6 @@ impl Receiver { let mut slot = self.shared.buffer[idx].read().unwrap(); if slot.pos != self.next { - let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); - - // The receiver has read all current values in the channel and there - // is no waiter to register - if waiter.is_none() && next_pos == self.next { - return Err(TryRecvError::Empty); - } - // Release the `slot` lock before attempting to acquire the `tail` // lock. This is required because `send2` acquires the tail lock // first followed by the slot lock. Acquiring the locks in reverse @@ -813,6 +798,13 @@ impl Receiver { let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64); if next_pos == self.next { + // At this point the channel is empty for *this* receiver. If + // it's been closed, then that's what we return, otherwise we + // set a waker and return empty. + if tail.closed { + return Err(TryRecvError::Closed); + } + // Store the waker if let Some((waiter, waker)) = waiter { // Safety: called while locked. @@ -846,22 +838,7 @@ impl Receiver { // catch up by skipping dropped messages and setting the // internal cursor to the **oldest** message stored by the // channel. - // - // However, finding the oldest position is a bit more - // complicated than `tail-position - buffer-size`. When - // the channel is closed, the tail position is incremented to - // signal a new `None` message, but `None` is not stored in the - // channel itself (see issue #2425 for why). - // - // To account for this, if the channel is closed, the tail - // position is decremented by `buffer-size + 1`. - let mut adjust = 0; - if tail.closed { - adjust = 1 - } - let next = tail - .pos - .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); let missed = next.wrapping_sub(self.next); @@ -882,10 +859,6 @@ impl Receiver { self.next = self.next.wrapping_add(1); - if slot.closed { - return Err(TryRecvError::Closed); - } - Ok(RecvGuard { slot }) } } diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index cee888dd2c0..9aa34841e26 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -47,7 +47,7 @@ macro_rules! assert_closed { ($e:expr) => { match assert_err!($e) { broadcast::error::TryRecvError::Closed => {} - _ => panic!("did not lag"), + _ => panic!("is not closed"), } }; } @@ -517,3 +517,12 @@ fn resubscribe_lagged() { assert_empty!(rx); assert_empty!(rx_resub); } + +#[test] +fn resubscribe_to_closed_channel() { + let (tx, rx) = tokio::sync::broadcast::channel::(2); + drop(tx); + + let mut rx_resub = rx.resubscribe(); + assert_closed!(rx_resub.try_recv()); +}