From cf025ba45f68934ae2138bb75ee2a5ee50506d1b Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Thu, 24 Sep 2020 17:26:38 -0700 Subject: [PATCH] sync: support mpsc send with `&self` (#2861) Updates the mpsc channel to use the intrusive waker based sempahore. This enables using `Sender` with `&self`. Instead of using `Sender::poll_ready` to ensure capacity and updating the `Sender` state, `async fn Sender::reserve()` is added. This function returns a `Permit` value representing the reserved capacity. Fixes: #2637 Refs: #2718 (intrusive waiters) --- tokio-test/src/io.rs | 4 +- tokio/src/signal/unix.rs | 30 +- tokio/src/stream/mod.rs | 4 +- tokio/src/stream/stream_map.rs | 4 +- tokio/src/sync/batch_semaphore.rs | 10 +- tokio/src/sync/mod.rs | 9 +- tokio/src/sync/mpsc/bounded.rs | 338 +++--- tokio/src/sync/mpsc/chan.rs | 268 +---- tokio/src/sync/mpsc/error.rs | 20 - tokio/src/sync/mpsc/mod.rs | 2 +- tokio/src/sync/mpsc/unbounded.rs | 39 +- tokio/src/sync/semaphore_ll.rs | 1221 --------------------- tokio/src/sync/tests/loom_mpsc.rs | 14 +- tokio/src/sync/tests/loom_semaphore_ll.rs | 192 ---- tokio/src/sync/tests/mod.rs | 2 - tokio/src/sync/tests/semaphore_ll.rs | 470 -------- tokio/src/util/linked_list.rs | 20 +- tokio/tests/rt_threaded.rs | 10 +- tokio/tests/sync_mpsc.rs | 363 +++--- 19 files changed, 459 insertions(+), 2561 deletions(-) delete mode 100644 tokio/src/sync/semaphore_ll.rs delete mode 100644 tokio/src/sync/tests/loom_semaphore_ll.rs delete mode 100644 tokio/src/sync/tests/semaphore_ll.rs diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs index 4f0b589743b..b91ddc34ea2 100644 --- a/tokio-test/src/io.rs +++ b/tokio-test/src/io.rs @@ -200,7 +200,9 @@ impl Inner { } fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.rx.poll_recv(cx) + use futures_core::stream::Stream; + + Pin::new(&mut self.rx).poll_next(cx) } fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> { diff --git a/tokio/src/signal/unix.rs b/tokio/src/signal/unix.rs index 45a091d76a6..30a05872331 100644 --- a/tokio/src/signal/unix.rs +++ b/tokio/src/signal/unix.rs @@ -391,35 +391,7 @@ impl Signal { poll_fn(|cx| self.poll_recv(cx)).await } - /// Polls to receive the next signal notification event, outside of an - /// `async` context. - /// - /// `None` is returned if no more events can be received by this stream. - /// - /// # Examples - /// - /// Polling from a manually implemented future - /// - /// ```rust,no_run - /// use std::pin::Pin; - /// use std::future::Future; - /// use std::task::{Context, Poll}; - /// use tokio::signal::unix::Signal; - /// - /// struct MyFuture { - /// signal: Signal, - /// } - /// - /// impl Future for MyFuture { - /// type Output = Option<()>; - /// - /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - /// println!("polling MyFuture"); - /// self.signal.poll_recv(cx) - /// } - /// } - /// ``` - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { self.rx.poll_recv(cx) } } diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index 6a99d9d85b2..59e1482f33b 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -270,8 +270,8 @@ pub trait StreamExt: Stream { /// # #[tokio::main(basic_scheduler)] /// async fn main() { /// # time::pause(); - /// let (mut tx1, rx1) = mpsc::channel(10); - /// let (mut tx2, rx2) = mpsc::channel(10); + /// let (tx1, rx1) = mpsc::channel(10); + /// let (tx2, rx2) = mpsc::channel(10); /// /// let mut rx = rx1.merge(rx2); /// diff --git a/tokio/src/stream/stream_map.rs b/tokio/src/stream/stream_map.rs index 2f60ea4ddaf..a1c80f1520c 100644 --- a/tokio/src/stream/stream_map.rs +++ b/tokio/src/stream/stream_map.rs @@ -57,8 +57,8 @@ use std::task::{Context, Poll}; /// /// #[tokio::main] /// async fn main() { -/// let (mut tx1, rx1) = mpsc::channel(10); -/// let (mut tx2, rx2) = mpsc::channel(10); +/// let (tx1, rx1) = mpsc::channel(10); +/// let (tx2, rx2) = mpsc::channel(10); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index a1048ca3734..9f324f9c928 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -165,7 +165,6 @@ impl Semaphore { /// permits and notifies all pending waiters. // This will be used once the bounded MPSC is updated to use the new // semaphore implementation. - #[allow(dead_code)] pub(crate) fn close(&self) { let mut waiters = self.waiters.lock().unwrap(); // If the semaphore's permits counter has enough permits for an @@ -185,6 +184,11 @@ impl Semaphore { } } + /// Returns true if the semaphore is closed + pub(crate) fn is_closed(&self) -> bool { + self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED + } + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { assert!( num_permits as usize <= Self::MAX_PERMITS, @@ -194,8 +198,8 @@ impl Semaphore { let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; let mut curr = self.permits.load(Acquire); loop { - // Has the semaphore closed?git - if curr & Self::CLOSED > 0 { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { return Err(TryAcquireError::Closed); } diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 4c069467dee..6531931b365 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -106,7 +106,7 @@ //! //! #[tokio::main] //! async fn main() { -//! let (mut tx, mut rx) = mpsc::channel(100); +//! let (tx, mut rx) = mpsc::channel(100); //! //! tokio::spawn(async move { //! for i in 0..10 { @@ -150,7 +150,7 @@ //! for _ in 0..10 { //! // Each task needs its own `tx` handle. This is done by cloning the //! // original handle. -//! let mut tx = tx.clone(); +//! let tx = tx.clone(); //! //! tokio::spawn(async move { //! tx.send(&b"data to write"[..]).await.unwrap(); @@ -213,7 +213,7 @@ //! //! // Spawn tasks that will send the increment command. //! for _ in 0..10 { -//! let mut cmd_tx = cmd_tx.clone(); +//! let cmd_tx = cmd_tx.clone(); //! //! join_handles.push(tokio::spawn(async move { //! let (resp_tx, resp_rx) = oneshot::channel(); @@ -443,7 +443,6 @@ cfg_sync! { pub mod oneshot; pub(crate) mod batch_semaphore; - pub(crate) mod semaphore_ll; mod semaphore; pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; @@ -473,7 +472,7 @@ cfg_not_sync! { cfg_signal_internal! { pub(crate) mod mpsc; - pub(crate) mod semaphore_ll; + pub(crate) mod batch_semaphore; } } diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 14e4731aaae..2d2006d5883 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1,6 +1,6 @@ +use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError}; -use crate::sync::semaphore_ll as semaphore; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -8,6 +8,7 @@ cfg_time! { } use std::fmt; +#[cfg(any(feature = "signal", feature = "process", feature = "stream"))] use std::task::{Context, Poll}; /// Send values to the associated `Receiver`. @@ -17,20 +18,14 @@ pub struct Sender { chan: chan::Tx, } -impl Clone for Sender { - fn clone(&self) -> Self { - Sender { - chan: self.chan.clone(), - } - } -} - -impl fmt::Debug for Sender { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Sender") - .field("chan", &self.chan) - .finish() - } +/// Permit to send one value into the channel. +/// +/// `Permit` values are returned by [`Sender::reserve()`] and are used to +/// guarantee channel capacity before generating a message to send. +/// +/// [`Sender::reserve()`]: Sender::reserve +pub struct Permit<'a, T> { + chan: &'a chan::Tx, } /// Receive values from the associated `Sender`. @@ -41,14 +36,6 @@ pub struct Receiver { chan: chan::Rx, } -impl fmt::Debug for Receiver { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Receiver") - .field("chan", &self.chan) - .finish() - } -} - /// Creates a bounded mpsc channel for communicating between asynchronous tasks /// with backpressure. /// @@ -77,7 +64,7 @@ impl fmt::Debug for Receiver { /// /// #[tokio::main] /// async fn main() { -/// let (mut tx, mut rx) = mpsc::channel(100); +/// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -125,7 +112,7 @@ impl Receiver { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// tx.send("hello").await.unwrap(); @@ -143,7 +130,7 @@ impl Receiver { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tx.send("hello").await.unwrap(); /// tx.send("world").await.unwrap(); @@ -154,12 +141,11 @@ impl Receiver { /// ``` pub async fn recv(&mut self) -> Option { use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await + poll_fn(|cx| self.chan.recv(cx)).await } - #[doc(hidden)] // TODO: document - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + #[cfg(any(feature = "signal", feature = "process"))] + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { self.chan.recv(cx) } @@ -178,7 +164,7 @@ impl Receiver { /// use tokio::sync::mpsc; /// /// fn main() { - /// let (mut tx, mut rx) = mpsc::channel::(10); + /// let (tx, mut rx) = mpsc::channel::(10); /// /// let sync_code = thread::spawn(move || { /// assert_eq!(Some(10), rx.blocking_recv()); @@ -215,12 +201,53 @@ impl Receiver { /// Closes the receiving half of a channel, without dropping it. /// /// This prevents any further messages from being sent on the channel while - /// still enabling the receiver to drain messages that are buffered. + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// In order to guarantee no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. + /// + /// [`Permit`]: Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(20); + /// + /// tokio::spawn(async move { + /// let mut i = 0; + /// while let Ok(permit) = tx.reserve().await { + /// permit.send(i); + /// i += 1; + /// } + /// }); + /// + /// rx.close(); + /// + /// while let Some(msg) = rx.recv().await { + /// println!("got {}", msg); + /// } + /// + /// // Channel closed and no messages are lost. + /// } + /// ``` pub fn close(&mut self) { self.chan.close(); } } +impl fmt::Debug for Receiver { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + impl Unpin for Receiver {} cfg_stream! { @@ -228,7 +255,7 @@ cfg_stream! { type Item = T; fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_recv(cx) + self.chan.recv(cx) } } } @@ -267,7 +294,7 @@ impl Sender { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -283,17 +310,13 @@ impl Sender { /// } /// } /// ``` - pub async fn send(&mut self, value: T) -> Result<(), SendError> { - use crate::future::poll_fn; - - if poll_fn(|cx| self.poll_ready(cx)).await.is_err() { - return Err(SendError(value)); - } - - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendError(value)), + pub async fn send(&self, value: T) -> Result<(), SendError> { + match self.reserve().await { + Ok(permit) => { + permit.send(value); + Ok(()) + } + Err(_) => Err(SendError(value)), } } @@ -304,9 +327,6 @@ impl Sender { /// with [`send`], this function has two failure cases instead of one (one for /// disconnection, one for a full buffer). /// - /// This function may be paired with [`poll_ready`] in order to wait for - /// channel capacity before trying to send a value. - /// /// # Errors /// /// If the channel capacity has been reached, i.e., the channel has `n` @@ -318,7 +338,6 @@ impl Sender { /// an error. The error includes the value passed to `send`. /// /// [`send`]: Sender::send - /// [`poll_ready`]: Sender::poll_ready /// [`channel`]: channel /// [`close`]: Receiver::close /// @@ -330,8 +349,8 @@ impl Sender { /// #[tokio::main] /// async fn main() { /// // Create a channel with buffer size 1 - /// let (mut tx1, mut rx) = mpsc::channel(1); - /// let mut tx2 = tx1.clone(); + /// let (tx1, mut rx) = mpsc::channel(1); + /// let tx2 = tx1.clone(); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); @@ -359,8 +378,15 @@ impl Sender { /// } /// } /// ``` - pub fn try_send(&mut self, message: T) -> Result<(), TrySendError> { - self.chan.try_send(message)?; + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), + } + + // Send the message + self.chan.send(message); Ok(()) } @@ -392,7 +418,7 @@ impl Sender { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -412,27 +438,22 @@ impl Sender { #[cfg(feature = "time")] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] pub async fn send_timeout( - &mut self, + &self, value: T, timeout: Duration, ) -> Result<(), SendTimeoutError> { - use crate::future::poll_fn; - - match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await { + let permit = match crate::time::timeout(timeout, self.reserve()).await { Err(_) => { return Err(SendTimeoutError::Timeout(value)); } Ok(Err(_)) => { return Err(SendTimeoutError::Closed(value)); } - Ok(_) => {} - } + Ok(Ok(permit)) => permit, + }; - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)), - } + permit.send(value); + Ok(()) } /// Blocking send to call outside of asynchronous contexts. @@ -450,7 +471,7 @@ impl Sender { /// use tokio::sync::mpsc; /// /// fn main() { - /// let (mut tx, mut rx) = mpsc::channel::(1); + /// let (tx, mut rx) = mpsc::channel::(1); /// /// let sync_code = thread::spawn(move || { /// tx.blocking_send(10).unwrap(); @@ -462,92 +483,139 @@ impl Sender { /// sync_code.join().unwrap() /// } /// ``` - pub fn blocking_send(&mut self, value: T) -> Result<(), SendError> { + pub fn blocking_send(&self, value: T) -> Result<(), SendError> { let mut enter_handle = crate::runtime::enter::enter(false); enter_handle.block_on(self.send(value)).unwrap() } - /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item. + /// Wait for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. /// - /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a - /// slot becomes available. + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. A [`Permit`] is returned to track + /// the reserved capacity. The [`send`] function on [`Permit`] consumes the + /// reserved capacity. /// - /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless - /// the channel has since been closed. To provide this guarantee, the channel reserves one slot - /// in the channel for the coming send. This reserved slot is not available to other `Sender` - /// instances, so you need to be careful to not end up with deadlocks by blocking after calling - /// `poll_ready` but before sending an element. + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. /// - /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you - /// can use [`disarm`](Sender::disarm) to release the reserved slot. + /// [`Permit`]: Permit + /// [`send`]: Permit::send /// - /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to - /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel - /// is closed. - pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.chan.poll_ready(cx).map_err(|_| ClosedError::new()) - } - - /// Undo a successful call to `poll_ready`. + /// # Examples /// - /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the - /// channel to make room for the coming send. `disarm` allows you to give up that slot if you - /// decide you do not wish to send an item after all. After calling `disarm`, you must call - /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again. + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); /// - /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was - /// not previously called, or did not succeed). + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); /// - /// # Motivation + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); /// - /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers - /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may - /// take up all the slots of the channel, and prevent active senders from getting any requests - /// through. Consider this code that forwards from one channel to another: + /// // Sending on the permit succeeds + /// permit.send(456); /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// if let Some(item) = ready!(rx.poll_recv(cx)) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` + pub async fn reserve(&self) -> Result, SendError<()>> { + match self.chan.semaphore().0.acquire(1).await { + Ok(_) => {} + Err(_) => return Err(SendError(())), + } + + Ok(Permit { chan: &self.chan }) + } +} + +impl Clone for Sender { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl Permit<'_, T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); /// - /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then - /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do, - /// they are effectively each reducing the channel's capacity by 1. If enough of these - /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot - /// for them through `poll_ready`, and the system will deadlock. - /// - /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that - /// you have to block. We can then fix the code above by writing: - /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// let item = rx.poll_recv(cx); - /// if let Poll::Ready(Ok(_)) = item { - /// // we're going to send the item below, so don't disarm - /// } else { - /// // give up our send slot, we won't need it for a while - /// tx.disarm(); - /// } - /// if let Some(item) = ready!(item) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // Send a message on the permit + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` - pub fn disarm(&mut self) -> bool { - if self.chan.is_ready() { - self.chan.disarm(); - true - } else { - false + pub fn send(self, value: T) { + use std::mem; + + self.chan.send(value); + + // Avoid the drop logic + mem::forget(self); + } +} + +impl Drop for Permit<'_, T> { + fn drop(&mut self) { + use chan::Semaphore; + + let semaphore = self.chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + if semaphore.is_closed() && semaphore.is_idle() { + self.chan.wake_rx(); } } } + +impl fmt::Debug for Permit<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 0a53cda2038..2d3f014996a 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -2,8 +2,8 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; -use crate::sync::mpsc::error::{ClosedError, TryRecvError}; -use crate::sync::mpsc::{error, list}; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::list; use std::fmt; use std::process; @@ -12,21 +12,13 @@ use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; /// Channel sender -pub(crate) struct Tx { +pub(crate) struct Tx { inner: Arc>, - permit: S::Permit, } -impl fmt::Debug for Tx -where - S::Permit: fmt::Debug, - S: fmt::Debug, -{ +impl fmt::Debug for Tx { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Tx") - .field("inner", &self.inner) - .field("permit", &self.permit) - .finish() + fmt.debug_struct("Tx").field("inner", &self.inner).finish() } } @@ -35,71 +27,20 @@ pub(crate) struct Rx { inner: Arc>, } -impl fmt::Debug for Rx -where - S: fmt::Debug, -{ +impl fmt::Debug for Rx { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Rx").field("inner", &self.inner).finish() } } -#[derive(Debug, Eq, PartialEq)] -pub(crate) enum TrySendError { - Closed, - Full, -} - -impl From<(T, TrySendError)> for error::SendError { - fn from(src: (T, TrySendError)) -> error::SendError { - match src.1 { - TrySendError::Closed => error::SendError(src.0), - TrySendError::Full => unreachable!(), - } - } -} - -impl From<(T, TrySendError)> for error::TrySendError { - fn from(src: (T, TrySendError)) -> error::TrySendError { - match src.1 { - TrySendError::Closed => error::TrySendError::Closed(src.0), - TrySendError::Full => error::TrySendError::Full(src.0), - } - } -} - pub(crate) trait Semaphore { - type Permit; - - fn new_permit() -> Self::Permit; - - /// The permit is dropped without a value being sent. In this case, the - /// permit must be returned to the semaphore. - /// - /// # Return - /// - /// Returns true if the permit was acquired. - fn drop_permit(&self, permit: &mut Self::Permit) -> bool; - fn is_idle(&self) -> bool; fn add_permit(&self); - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Self::Permit, - ) -> Poll>; - - fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>; - - /// A value was sent into the channel and the permit held by `tx` is - /// dropped. In this case, the permit should not immeditely be returned to - /// the semaphore. Instead, the permit is returnred to the semaphore once - /// the sent value is read by the rx handle. - fn forget(&self, permit: &mut Self::Permit); - fn close(&self); + + fn is_closed(&self) -> bool; } struct Chan { @@ -157,10 +98,7 @@ impl fmt::Debug for RxFields { unsafe impl Send for Chan {} unsafe impl Sync for Chan {} -pub(crate) fn channel(semaphore: S) -> (Tx, Rx) -where - S: Semaphore, -{ +pub(crate) fn channel(semaphore: S) -> (Tx, Rx) { let (tx, rx) = list::channel(); let chan = Arc::new(Chan { @@ -179,48 +117,27 @@ where // ===== impl Tx ===== -impl Tx -where - S: Semaphore, -{ +impl Tx { fn new(chan: Arc>) -> Tx { - Tx { - inner: chan, - permit: S::new_permit(), - } + Tx { inner: chan } } - pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.semaphore.poll_acquire(cx, &mut self.permit) - } - - pub(crate) fn disarm(&mut self) { - // TODO: should this error if not acquired? - self.inner.semaphore.drop_permit(&mut self.permit); + pub(super) fn semaphore(&self) -> &S { + &self.inner.semaphore } /// Send a message and notify the receiver. - pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut self.permit) - } -} - -impl Tx { - pub(crate) fn is_ready(&self) -> bool { - self.permit.is_acquired() + pub(crate) fn send(&self, value: T) { + self.inner.send(value); } -} -impl Tx { - pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut ()) + /// Wake the receive half + pub(crate) fn wake_rx(&self) { + self.inner.rx_waker.wake(); } } -impl Clone for Tx -where - S: Semaphore, -{ +impl Clone for Tx { fn clone(&self) -> Tx { // Using a Relaxed ordering here is sufficient as the caller holds a // strong ref to `self`, preventing a concurrent decrement to zero. @@ -228,22 +145,12 @@ where Tx { inner: self.inner.clone(), - permit: S::new_permit(), } } } -impl Drop for Tx -where - S: Semaphore, -{ +impl Drop for Tx { fn drop(&mut self) { - let notify = self.inner.semaphore.drop_permit(&mut self.permit); - - if notify && self.inner.semaphore.is_idle() { - self.inner.rx_waker.wake(); - } - if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { return; } @@ -252,16 +159,13 @@ where self.inner.tx.close(); // Notify the receiver - self.inner.rx_waker.wake(); + self.wake_rx(); } } // ===== impl Rx ===== -impl Rx -where - S: Semaphore, -{ +impl Rx { fn new(chan: Arc>) -> Rx { Rx { inner: chan } } @@ -349,10 +253,7 @@ where } } -impl Drop for Rx -where - S: Semaphore, -{ +impl Drop for Rx { fn drop(&mut self) { use super::block::Read::Value; @@ -370,25 +271,13 @@ where // ===== impl Chan ===== -impl Chan -where - S: Semaphore, -{ - fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> { - if let Err(e) = self.semaphore.try_acquire(permit) { - return Err((value, e)); - } - +impl Chan { + fn send(&self, value: T) { // Push the value self.tx.push(value); // Notify the rx task self.rx_waker.wake(); - - // Release the permit - self.semaphore.forget(permit); - - Ok(()) } } @@ -407,74 +296,24 @@ impl Drop for Chan { } } -use crate::sync::semaphore_ll::TryAcquireError; - -impl From for TrySendError { - fn from(src: TryAcquireError) -> TrySendError { - if src.is_closed() { - TrySendError::Closed - } else if src.is_no_permits() { - TrySendError::Full - } else { - unreachable!(); - } - } -} - // ===== impl Semaphore for (::Semaphore, capacity) ===== -use crate::sync::semaphore_ll::Permit; - -impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { - type Permit = Permit; - - fn new_permit() -> Permit { - Permit::new() - } - - fn drop_permit(&self, permit: &mut Permit) -> bool { - let ret = permit.is_acquired(); - permit.release(1, &self.0); - ret - } - +impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { fn add_permit(&self) { - self.0.add_permits(1) + self.0.release(1) } fn is_idle(&self) -> bool { self.0.available_permits() == self.1 } - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Permit, - ) -> Poll> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - permit - .poll_acquire(cx, 1, &self.0) - .map_err(|_| ClosedError::new()) - .map(move |r| { - coop.made_progress(); - r - }) - } - - fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { - permit.try_acquire(1, &self.0)?; - Ok(()) - } - - fn forget(&self, permit: &mut Self::Permit) { - permit.forget(1); - } - fn close(&self) { self.0.close(); } + + fn is_closed(&self) -> bool { + self.0.is_closed() + } } // ===== impl Semaphore for AtomicUsize ===== @@ -483,14 +322,6 @@ use std::sync::atomic::Ordering::{Acquire, Release}; use std::usize; impl Semaphore for AtomicUsize { - type Permit = (); - - fn new_permit() {} - - fn drop_permit(&self, _permit: &mut ()) -> bool { - false - } - fn add_permit(&self) { let prev = self.fetch_sub(2, Release); @@ -504,40 +335,11 @@ impl Semaphore for AtomicUsize { self.load(Acquire) >> 1 == 0 } - fn poll_acquire( - &self, - _cx: &mut Context<'_>, - permit: &mut (), - ) -> Poll> { - Ready(self.try_acquire(permit).map_err(|_| ClosedError::new())) - } - - fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> { - let mut curr = self.load(Acquire); - - loop { - if curr & 1 == 1 { - return Err(TrySendError::Closed); - } - - if curr == usize::MAX ^ 1 { - // Overflowed the ref count. There is no safe way to recover, so - // abort the process. In practice, this should never happen. - process::abort() - } - - match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) { - Ok(_) => return Ok(()), - Err(actual) => { - curr = actual; - } - } - } - } - - fn forget(&self, _permit: &mut ()) {} - fn close(&self) { self.fetch_or(1, Release); } + + fn is_closed(&self) -> bool { + self.load(Acquire) & 1 == 1 + } } diff --git a/tokio/src/sync/mpsc/error.rs b/tokio/src/sync/mpsc/error.rs index 72c42aa53e7..77054529c69 100644 --- a/tokio/src/sync/mpsc/error.rs +++ b/tokio/src/sync/mpsc/error.rs @@ -94,26 +94,6 @@ impl fmt::Display for TryRecvError { impl Error for TryRecvError {} -// ===== ClosedError ===== - -/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready). -#[derive(Debug)] -pub struct ClosedError(()); - -impl ClosedError { - pub(crate) fn new() -> ClosedError { - ClosedError(()) - } -} - -impl fmt::Display for ClosedError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "channel closed") - } -} - -impl Error for ClosedError {} - cfg_time! { // ===== SendTimeoutError ===== diff --git a/tokio/src/sync/mpsc/mod.rs b/tokio/src/sync/mpsc/mod.rs index 7e663da89f0..a2bcf83b0ee 100644 --- a/tokio/src/sync/mpsc/mod.rs +++ b/tokio/src/sync/mpsc/mod.rs @@ -76,7 +76,7 @@ pub(super) mod block; mod bounded; -pub use self::bounded::{channel, Receiver, Sender}; +pub use self::bounded::{channel, Permit, Receiver, Sender}; mod chan; diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index 6b2ca722729..59456375297 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -73,8 +73,7 @@ impl UnboundedReceiver { UnboundedReceiver { chan } } - #[doc(hidden)] // TODO: doc - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { self.chan.recv(cx) } @@ -174,7 +173,41 @@ impl UnboundedSender { /// [`close`]: UnboundedReceiver::close /// [`UnboundedReceiver`]: UnboundedReceiver pub fn send(&self, message: T) -> Result<(), SendError> { - self.chan.send_unbounded(message)?; + if !self.inc_num_messages() { + return Err(SendError(message)); + } + + self.chan.send(message); Ok(()) } + + fn inc_num_messages(&self) -> bool { + use std::process; + use std::sync::atomic::Ordering::{AcqRel, Acquire}; + + let mut curr = self.chan.semaphore().load(Acquire); + + loop { + if curr & 1 == 1 { + return false; + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self + .chan + .semaphore() + .compare_exchange(curr, curr + 2, AcqRel, Acquire) + { + Ok(_) => return true, + Err(actual) => { + curr = actual; + } + } + } + } } diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs deleted file mode 100644 index f044095f8fc..00000000000 --- a/tokio/src/sync/semaphore_ll.rs +++ /dev/null @@ -1,1221 +0,0 @@ -#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] - -//! Thread-safe, asynchronous counting semaphore. -//! -//! A `Semaphore` instance holds a set of permits. Permits are used to -//! synchronize access to a shared resource. -//! -//! Before accessing the shared resource, callers acquire a permit from the -//! semaphore. Once the permit is acquired, the caller then enters the critical -//! section. If no permits are available, then acquiring the semaphore returns -//! `Pending`. The task is woken once a permit becomes available. - -use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; - -use std::cmp; -use std::fmt; -use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; -use std::usize; - -/// Futures-aware semaphore. -pub(crate) struct Semaphore { - /// Tracks both the waiter queue tail pointer and the number of remaining - /// permits. - state: AtomicUsize, - - /// waiter queue head pointer. - head: UnsafeCell>, - - /// Coordinates access to the queue head. - rx_lock: AtomicUsize, - - /// Stub waiter node used as part of the MPSC channel algorithm. - stub: Box, -} - -/// A semaphore permit -/// -/// Tracks the lifecycle of a semaphore permit. -/// -/// An instance of `Permit` is intended to be used with a **single** instance of -/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore -/// instances will result in unexpected behavior. -/// -/// `Permit` does **not** release the permit back to the semaphore on drop. It -/// is the user's responsibility to ensure that `Permit::release` is called -/// before dropping the permit. -#[derive(Debug)] -pub(crate) struct Permit { - waiter: Option>, - state: PermitState, -} - -/// Error returned by `Permit::poll_acquire`. -#[derive(Debug)] -pub(crate) struct AcquireError(()); - -/// Error returned by `Permit::try_acquire`. -#[derive(Debug)] -pub(crate) enum TryAcquireError { - Closed, - NoPermits, -} - -/// Node used to notify the semaphore waiter when permit is available. -#[derive(Debug)] -struct Waiter { - /// Stores waiter state. - /// - /// See `WaiterState` for more details. - state: AtomicUsize, - - /// Task to wake when a permit is made available. - waker: AtomicWaker, - - /// Next pointer in the queue of waiting senders. - next: AtomicPtr, -} - -/// Semaphore state -/// -/// The 2 low bits track the modes. -/// -/// - Closed -/// - Full -/// -/// When not full, the rest of the `usize` tracks the total number of messages -/// in the channel. When full, the rest of the `usize` is a pointer to the tail -/// of the "waiting senders" queue. -#[derive(Copy, Clone)] -struct SemState(usize); - -/// Permit state -#[derive(Debug, Copy, Clone)] -enum PermitState { - /// Currently waiting for permits to be made available and assigned to the - /// waiter. - Waiting(u16), - - /// The number of acquired permits - Acquired(u16), -} - -/// State for an individual waker node -#[derive(Debug, Copy, Clone)] -struct WaiterState(usize); - -/// Waiter node is in the semaphore queue -const QUEUED: usize = 0b001; - -/// Semaphore has been closed, no more permits will be issued. -const CLOSED: usize = 0b10; - -/// The permit that owns the `Waiter` dropped. -const DROPPED: usize = 0b100; - -/// Represents "one requested permit" in the waiter state -const PERMIT_ONE: usize = 0b1000; - -/// Masks the waiter state to only contain bits tracking number of requested -/// permits. -const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1); - -/// How much to shift a permit count to pack it into the waker state -const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros(); - -/// Flag differentiating between available permits and waiter pointers. -/// -/// If we assume pointers are properly aligned, then the least significant bit -/// will always be zero. So, we use that bit to track if the value represents a -/// number. -const NUM_FLAG: usize = 0b01; - -/// Signal the semaphore is closed -const CLOSED_FLAG: usize = 0b10; - -/// Maximum number of permits a semaphore can manage -const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT; - -/// When representing "numbers", the state has to be shifted this much (to get -/// rid of the flag bit). -const NUM_SHIFT: usize = 2; - -// ===== impl Semaphore ===== - -impl Semaphore { - /// Creates a new semaphore with the initial number of permits - /// - /// # Panics - /// - /// Panics if `permits` is zero. - pub(crate) fn new(permits: usize) -> Semaphore { - let stub = Box::new(Waiter::new()); - let ptr = NonNull::from(&*stub); - - // Allocations are aligned - debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0); - - let state = SemState::new(permits, &stub); - - Semaphore { - state: AtomicUsize::new(state.to_usize()), - head: UnsafeCell::new(ptr), - rx_lock: AtomicUsize::new(0), - stub, - } - } - - /// Returns the current number of available permits - pub(crate) fn available_permits(&self) -> usize { - let curr = SemState(self.state.load(Acquire)); - curr.available_permits() - } - - /// Tries to acquire the requested number of permits, registering the waiter - /// if not enough permits are available. - fn poll_acquire( - &self, - cx: &mut Context<'_>, - num_permits: u16, - permit: &mut Permit, - ) -> Poll> { - self.poll_acquire2(num_permits, || { - let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new())); - - waiter.waker.register_by_ref(cx.waker()); - - Some(NonNull::from(&**waiter)) - }) - } - - fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { - match self.poll_acquire2(num_permits, || None) { - Poll::Ready(res) => res.map_err(to_try_acquire), - Poll::Pending => Err(TryAcquireError::NoPermits), - } - } - - /// Polls for a permit - /// - /// Tries to acquire available permits first. If unable to acquire a - /// sufficient number of permits, the caller's waiter is pushed onto the - /// semaphore's wait queue. - fn poll_acquire2( - &self, - num_permits: u16, - mut get_waiter: F, - ) -> Poll> - where - F: FnMut() -> Option>, - { - let num_permits = num_permits as usize; - - // Load the current state - let mut curr = SemState(self.state.load(Acquire)); - - // Saves a ref to the waiter node - let mut maybe_waiter: Option> = None; - - /// Used in branches where we attempt to push the waiter into the wait - /// queue but fail due to permits becoming available or the wait queue - /// transitioning to "closed". In this case, the waiter must be - /// transitioned back to the "idle" state. - macro_rules! revert_to_idle { - () => { - if let Some(waiter) = maybe_waiter { - unsafe { waiter.as_ref() }.revert_to_idle(); - } - }; - } - - loop { - let mut next = curr; - - if curr.is_closed() { - revert_to_idle!(); - return Ready(Err(AcquireError::closed())); - } - - let acquired = next.acquire_permits(num_permits, &self.stub); - - if !acquired { - // There are not enough available permits to satisfy the - // request. The permit transitions to a waiting state. - debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits); - - if let Some(waiter) = maybe_waiter.as_ref() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - w.set_permits_to_acquire(num_permits - curr.available_permits()); - } else { - // Get the waiter for the permit. - if let Some(waiter) = get_waiter() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - - // If there are any currently available permits, the - // waiter acquires those immediately and waits for the - // remaining permits to become available. - if !w.to_queued(num_permits - curr.available_permits()) { - // The node is alrady queued, there is no further work - // to do. - return Pending; - } - - maybe_waiter = Some(waiter); - } else { - // No waiter, this indicates the caller does not wish to - // "wait", so there is nothing left to do. - return Pending; - } - } - - next.set_waiter(maybe_waiter.unwrap()); - } - - debug_assert_ne!(curr.0, 0); - debug_assert_ne!(next.0, 0); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if acquired { - // Successfully acquire permits **without** queuing the - // waiter node. The waiter node is not currently in the - // queue. - revert_to_idle!(); - return Ready(Ok(())); - } else { - // The node is pushed into the queue, the final step is - // to set the node's "next" pointer to return the wait - // queue into a consistent state. - - let prev_waiter = - curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub)); - - let waiter = maybe_waiter.unwrap(); - - // Link the nodes. - // - // Safety: the mpsc algorithm guarantees the old tail of - // the queue is not removed from the queue during the - // push process. - unsafe { - prev_waiter.as_ref().store_next(waiter); - } - - return Pending; - } - } - Err(actual) => { - curr = SemState(actual); - } - } - } - } - - /// Closes the semaphore. This prevents the semaphore from issuing new - /// permits and notifies all pending waiters. - pub(crate) fn close(&self) { - // Acquire the `rx_lock`, setting the "closed" flag on the lock. - let prev = self.rx_lock.fetch_or(1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(0, true); - } - /// Adds `n` new permits to the semaphore. - /// - /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. - pub(crate) fn add_permits(&self, n: usize) { - if n == 0 { - return; - } - - // TODO: Handle overflow. A panic is not sufficient, the process must - // abort. - let prev = self.rx_lock.fetch_add(n << 1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(n, false); - } - - fn add_permits_locked(&self, mut rem: usize, mut closed: bool) { - while rem > 0 || closed { - if closed { - SemState::fetch_set_closed(&self.state, AcqRel); - } - - // Release the permits and notify - self.add_permits_locked2(rem, closed); - - let n = rem << 1; - - let actual = if closed { - let actual = self.rx_lock.fetch_sub(n | 1, AcqRel); - closed = false; - actual - } else { - let actual = self.rx_lock.fetch_sub(n, AcqRel); - closed = actual & 1 == 1; - actual - }; - - rem = (actual >> 1) - rem; - } - } - - /// Releases a specific amount of permits to the semaphore - /// - /// This function is called by `add_permits` after the add lock has been - /// acquired. - fn add_permits_locked2(&self, mut n: usize, closed: bool) { - // If closing the semaphore, we want to drain the entire queue. The - // number of permits being assigned doesn't matter. - if closed { - n = usize::MAX; - } - - 'outer: while n > 0 { - unsafe { - let mut head = self.head.with(|head| *head); - let mut next_ptr = head.as_ref().next.load(Acquire); - - let stub = self.stub(); - - if head == stub { - // The stub node indicates an empty queue. Any remaining - // permits get assigned back to the semaphore. - let next = match NonNull::new(next_ptr) { - Some(next) => next, - None => { - // This loop is not part of the standard intrusive mpsc - // channel algorithm. This is where we atomically pop - // the last task and add `n` to the remaining capacity. - // - // This modification to the pop algorithm works because, - // at this point, we have not done any work (only done - // reading). We have a *pretty* good idea that there is - // no concurrent pusher. - // - // The capacity is then atomically added by doing an - // AcqRel CAS on `state`. The `state` cell is the - // linchpin of the algorithm. - // - // By successfully CASing `head` w/ AcqRel, we ensure - // that, if any thread was racing and entered a push, we - // see that and abort pop, retrying as it is - // "inconsistent". - let mut curr = SemState::load(&self.state, Acquire); - - loop { - if curr.has_waiter(&self.stub) { - // A waiter is being added concurrently. - // This is the MPSC queue's "inconsistent" - // state and we must loop and try again. - thread::yield_now(); - continue 'outer; - } - - // If closing, nothing more to do. - if closed { - debug_assert!(curr.is_closed(), "state = {:?}", curr); - return; - } - - let mut next = curr; - next.release_permits(n, &self.stub); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = SemState(actual); - } - } - } - } - }; - - self.head.with_mut(|head| *head = next); - head = next; - next_ptr = next.as_ref().next.load(Acquire); - } - - // `head` points to a waiter assign permits to the waiter. If - // all requested permits are satisfied, then we can continue, - // otherwise the node stays in the wait queue. - if !head.as_ref().assign_permits(&mut n, closed) { - assert_eq!(n, 0); - return; - } - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - let state = SemState::load(&self.state, Acquire); - - // This must always be a pointer as the wait list is not empty. - let tail = state.waiter().unwrap(); - - if tail != head { - // Inconsistent - thread::yield_now(); - continue 'outer; - } - - self.push_stub(closed); - - next_ptr = head.as_ref().next.load(Acquire); - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - // Inconsistent state, loop - thread::yield_now(); - } - } - } - - /// The wait node has had all of its permits assigned and has been removed - /// from the wait queue. - /// - /// Attempt to remove the QUEUED bit from the node. If additional permits - /// are concurrently requested, the node must be pushed back into the wait - /// queued. - fn remove_queued(&self, waiter: NonNull, closed: bool) { - let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire)); - - loop { - if curr.is_dropped() { - // The Permit dropped, it is on us to release the memory - let _ = unsafe { Box::from_raw(waiter.as_ptr()) }; - return; - } - - // The node is removed from the queue. We attempt to unset the - // queued bit, but concurrently the waiter has requested more - // permits. When the waiter requested more permits, it saw the - // queued bit set so took no further action. This requires us to - // push the node back into the queue. - if curr.permits_to_acquire() > 0 { - // More permits are requested. The waiter must be re-queued - unsafe { - self.push_waiter(waiter, closed); - } - return; - } - - let mut next = curr; - next.unset_queued(); - - let w = unsafe { waiter.as_ref() }; - - match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = WaiterState(actual); - } - } - } - } - - unsafe fn push_stub(&self, closed: bool) { - self.push_waiter(self.stub(), closed); - } - - unsafe fn push_waiter(&self, waiter: NonNull, closed: bool) { - // Set the next pointer. This does not require an atomic operation as - // this node is not accessible. The write will be flushed with the next - // operation - waiter.as_ref().next.store(ptr::null_mut(), Relaxed); - - // Update the tail to point to the new node. We need to see the previous - // node in order to update the next pointer as well as release `task` - // to any other threads calling `push`. - let next = SemState::new_ptr(waiter, closed); - let prev = SemState(self.state.swap(next.0, AcqRel)); - - debug_assert_eq!(closed, prev.is_closed()); - - // This function is only called when there are pending tasks. Because of - // this, the state must *always* be in pointer mode. - let prev = prev.waiter().unwrap(); - - // No cycles plz - debug_assert_ne!(prev, waiter); - - // Release `task` to the consume end. - prev.as_ref().next.store(waiter.as_ptr(), Release); - } - - fn stub(&self) -> NonNull { - unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) } - } -} - -impl Drop for Semaphore { - fn drop(&mut self) { - self.close(); - } -} - -impl fmt::Debug for Semaphore { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Semaphore") - .field("state", &SemState::load(&self.state, Relaxed)) - .field("head", &self.head.with(|ptr| ptr)) - .field("rx_lock", &self.rx_lock.load(Relaxed)) - .field("stub", &self.stub) - .finish() - } -} - -unsafe impl Send for Semaphore {} -unsafe impl Sync for Semaphore {} - -// ===== impl Permit ===== - -impl Permit { - /// Creates a new `Permit`. - /// - /// The permit begins in the "unacquired" state. - pub(crate) fn new() -> Permit { - use PermitState::Acquired; - - Permit { - waiter: None, - state: Acquired(0), - } - } - - /// Returns `true` if the permit has been acquired - #[allow(dead_code)] // may be used later - pub(crate) fn is_acquired(&self) -> bool { - match self.state { - PermitState::Acquired(num) if num > 0 => true, - _ => false, - } - } - - /// Tries to acquire the permit. If no permits are available, the current task - /// is notified once a new permit becomes available. - pub(crate) fn poll_acquire( - &mut self, - cx: &mut Context<'_>, - num_permits: u16, - semaphore: &Semaphore, - ) -> Poll> { - use std::cmp::Ordering::*; - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - match requested.cmp(&num_permits) { - Less => { - let delta = num_permits - requested; - - // Request additional permits. If the waiter has been - // dequeued, it must be re-queued. - if !waiter.try_inc_permits_to_acquire(delta as usize) { - let waiter = NonNull::from(&**waiter); - - // Ignore the result. The check for - // `permits_to_acquire()` will converge the state as - // needed - let _ = semaphore.poll_acquire2(delta, || Some(waiter))?; - } - - self.state = Waiting(num_permits); - } - Greater => { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - Equal => {} - } - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - waiter.waker.register_by_ref(cx.waker()); - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - Pending - } - Acquired(acquired) => { - if acquired >= num_permits { - Ready(Ok(())) - } else { - match semaphore.poll_acquire(cx, num_permits - acquired, self)? { - Ready(()) => { - self.state = Acquired(num_permits); - Ready(Ok(())) - } - Pending => { - self.state = Waiting(num_permits); - Pending - } - } - } - } - } - } - - /// Tries to acquire the permit. - pub(crate) fn try_acquire( - &mut self, - num_permits: u16, - semaphore: &Semaphore, - ) -> Result<(), TryAcquireError> { - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - if requested > num_permits { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - - let res = waiter.permits_to_acquire().map_err(to_try_acquire)?; - - if res == 0 { - if requested < num_permits { - // Try to acquire the additional permits - semaphore.try_acquire(num_permits - requested)?; - } - - self.state = Acquired(num_permits); - Ok(()) - } else { - Err(TryAcquireError::NoPermits) - } - } - Acquired(acquired) => { - if acquired < num_permits { - semaphore.try_acquire(num_permits - acquired)?; - self.state = Acquired(num_permits); - } - - Ok(()) - } - } - } - - /// Releases a permit back to the semaphore - pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) { - let n = self.forget(n); - semaphore.add_permits(n as usize); - } - - /// Forgets the permit **without** releasing it back to the semaphore. - /// - /// After calling `forget`, `poll_acquire` is able to acquire new permit - /// from the semaphore. - /// - /// Repeatedly calling `forget` without associated calls to `add_permit` - /// will result in the semaphore losing all permits. - /// - /// Will forget **at most** the number of acquired permits. This number is - /// returned. - pub(crate) fn forget(&mut self, n: u16) -> u16 { - use PermitState::*; - - match self.state { - Waiting(requested) => { - let n = cmp::min(n, requested); - - // Decrement - let acquired = self - .waiter - .as_ref() - .unwrap() - .try_dec_permits_to_acquire(n as usize) as u16; - - if n == requested { - self.state = Acquired(0); - } else if acquired == requested - n { - self.state = Waiting(acquired); - } else { - self.state = Waiting(requested - n); - } - - acquired - } - Acquired(acquired) => { - let n = cmp::min(n, acquired); - self.state = Acquired(acquired - n); - n - } - } - } -} - -impl Default for Permit { - fn default() -> Self { - Self::new() - } -} - -impl Drop for Permit { - fn drop(&mut self) { - if let Some(waiter) = self.waiter.take() { - // Set the dropped flag - let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel)); - - if state.is_queued() { - // The waiter is stored in the queue. The semaphore will drop it - std::mem::forget(waiter); - } - } - } -} - -// ===== impl AcquireError ==== - -impl AcquireError { - fn closed() -> AcquireError { - AcquireError(()) - } -} - -fn to_try_acquire(_: AcquireError) -> TryAcquireError { - TryAcquireError::Closed -} - -impl fmt::Display for AcquireError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "semaphore closed") - } -} - -impl std::error::Error for AcquireError {} - -// ===== impl TryAcquireError ===== - -impl TryAcquireError { - /// Returns `true` if the error was caused by a closed semaphore. - pub(crate) fn is_closed(&self) -> bool { - match self { - TryAcquireError::Closed => true, - _ => false, - } - } - - /// Returns `true` if the error was caused by calling `try_acquire` on a - /// semaphore with no available permits. - pub(crate) fn is_no_permits(&self) -> bool { - match self { - TryAcquireError::NoPermits => true, - _ => false, - } - } -} - -impl fmt::Display for TryAcquireError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TryAcquireError::Closed => write!(fmt, "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "no permits available"), - } - } -} - -impl std::error::Error for TryAcquireError {} - -// ===== impl Waiter ===== - -impl Waiter { - fn new() -> Waiter { - Waiter { - state: AtomicUsize::new(0), - waker: AtomicWaker::new(), - next: AtomicPtr::new(ptr::null_mut()), - } - } - - fn permits_to_acquire(&self) -> Result { - let state = WaiterState(self.state.load(Acquire)); - - if state.is_closed() { - Err(AcquireError(())) - } else { - Ok(state.permits_to_acquire()) - } - } - - /// Only increments the number of permits *if* the waiter is currently - /// queued. - /// - /// # Returns - /// - /// `true` if the number of permits to acquire has been incremented. `false` - /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`. - fn try_inc_permits_to_acquire(&self, n: usize) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - if !curr.is_queued() { - assert_eq!(0, curr.permits_to_acquire()); - return false; - } - - let mut next = curr; - next.set_permits_to_acquire(n + curr.permits_to_acquire()); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return true, - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Try to decrement the number of permits to acquire. This returns the - /// actual number of permits that were decremented. The delta between `n` - /// and the return has been assigned to the permit and the caller must - /// assign these back to the semaphore. - fn try_dec_permits_to_acquire(&self, n: usize) -> usize { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - if !curr.is_queued() { - assert_eq!(0, curr.permits_to_acquire()); - } - - let delta = cmp::min(n, curr.permits_to_acquire()); - let rem = curr.permits_to_acquire() - delta; - - let mut next = curr; - next.set_permits_to_acquire(rem); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return n - delta, - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Store the number of remaining permits needed to satisfy the waiter and - /// transition to the "QUEUED" state. - /// - /// # Returns - /// - /// `true` if the `QUEUED` bit was set as part of the transition. - fn to_queued(&self, num_permits: usize) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - // The waiter should **not** be waiting for any permits. - debug_assert_eq!(curr.permits_to_acquire(), 0); - - loop { - let mut next = curr; - next.set_permits_to_acquire(num_permits); - next.set_queued(); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if curr.is_queued() { - return false; - } else { - // Make sure the next pointer is null - self.next.store(ptr::null_mut(), Relaxed); - return true; - } - } - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Set the number of permits to acquire. - /// - /// This function is only called when the waiter is being inserted into the - /// wait queue. Because of this, there are no concurrent threads that can - /// modify the state and using `store` is safe. - fn set_permits_to_acquire(&self, num_permits: usize) { - debug_assert!(WaiterState(self.state.load(Acquire)).is_queued()); - - let mut state = WaiterState(QUEUED); - state.set_permits_to_acquire(num_permits); - - self.state.store(state.0, Release); - } - - /// Assign permits to the waiter. - /// - /// Returns `true` if the waiter should be removed from the queue - fn assign_permits(&self, n: &mut usize, closed: bool) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - let mut next = curr; - - // Number of permits to assign to this waiter - let assign = cmp::min(curr.permits_to_acquire(), *n); - - // Assign the permits - next.set_permits_to_acquire(curr.permits_to_acquire() - assign); - - if closed { - next.set_closed(); - } - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - // Update `n` - *n -= assign; - - if next.permits_to_acquire() == 0 { - if curr.permits_to_acquire() > 0 { - self.waker.wake(); - } - - return true; - } else { - return false; - } - } - Err(actual) => curr = WaiterState(actual), - } - } - } - - fn revert_to_idle(&self) { - // An idle node is not waiting on any permits - self.state.store(0, Relaxed); - } - - fn store_next(&self, next: NonNull) { - self.next.store(next.as_ptr(), Release); - } -} - -// ===== impl SemState ===== - -impl SemState { - /// Returns a new default `State` value. - fn new(permits: usize, stub: &Waiter) -> SemState { - assert!(permits <= MAX_PERMITS); - - if permits > 0 { - SemState((permits << NUM_SHIFT) | NUM_FLAG) - } else { - SemState(stub as *const _ as usize) - } - } - - /// Returns a `State` tracking `ptr` as the tail of the queue. - fn new_ptr(tail: NonNull, closed: bool) -> SemState { - let mut val = tail.as_ptr() as usize; - - if closed { - val |= CLOSED_FLAG; - } - - SemState(val) - } - - /// Returns the amount of remaining capacity - fn available_permits(self) -> usize { - if !self.has_available_permits() { - return 0; - } - - self.0 >> NUM_SHIFT - } - - /// Returns `true` if the state has permits that can be claimed by a waiter. - fn has_available_permits(self) -> bool { - self.0 & NUM_FLAG == NUM_FLAG - } - - fn has_waiter(self, stub: &Waiter) -> bool { - !self.has_available_permits() && !self.is_stub(stub) - } - - /// Tries to atomically acquire specified number of permits. - /// - /// # Return - /// - /// Returns `true` if the specified number of permits were acquired, `false` - /// otherwise. Returning false does not mean that there are no more - /// available permits. - fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool { - debug_assert!(num > 0); - - if self.available_permits() < num { - return false; - } - - debug_assert!(self.waiter().is_none()); - - self.0 -= num << NUM_SHIFT; - - if self.0 == NUM_FLAG { - // Set the state to the stub pointer. - self.0 = stub as *const _ as usize; - } - - true - } - - /// Releases permits - /// - /// Returns `true` if the permits were accepted. - fn release_permits(&mut self, permits: usize, stub: &Waiter) { - debug_assert!(permits > 0); - - if self.is_stub(stub) { - self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG); - return; - } - - debug_assert!(self.has_available_permits()); - - self.0 += permits << NUM_SHIFT; - } - - fn is_waiter(self) -> bool { - self.0 & NUM_FLAG == 0 - } - - /// Returns the waiter, if one is set. - fn waiter(self) -> Option> { - if self.is_waiter() { - let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored"); - - Some(waiter) - } else { - None - } - } - - /// Assumes `self` represents a pointer - fn as_ptr(self) -> *mut Waiter { - (self.0 & !CLOSED_FLAG) as *mut Waiter - } - - /// Sets to a pointer to a waiter. - /// - /// This can only be done from the full state. - fn set_waiter(&mut self, waiter: NonNull) { - let waiter = waiter.as_ptr() as usize; - debug_assert!(!self.is_closed()); - - self.0 = waiter; - } - - fn is_stub(self, stub: &Waiter) -> bool { - self.as_ptr() as usize == stub as *const _ as usize - } - - /// Loads the state from an AtomicUsize. - fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState { - let value = cell.load(ordering); - SemState(value) - } - - fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState { - let value = cell.fetch_or(CLOSED_FLAG, ordering); - SemState(value) - } - - fn is_closed(self) -> bool { - self.0 & CLOSED_FLAG == CLOSED_FLAG - } - - /// Converts the state into a `usize` representation. - fn to_usize(self) -> usize { - self.0 - } -} - -impl fmt::Debug for SemState { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut fmt = fmt.debug_struct("SemState"); - - if self.is_waiter() { - fmt.field("state", &""); - } else { - fmt.field("permits", &self.available_permits()); - } - - fmt.finish() - } -} - -// ===== impl WaiterState ===== - -impl WaiterState { - fn permits_to_acquire(self) -> usize { - self.0 >> PERMIT_SHIFT - } - - fn set_permits_to_acquire(&mut self, val: usize) { - self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK) - } - - fn is_queued(self) -> bool { - self.0 & QUEUED == QUEUED - } - - fn set_queued(&mut self) { - self.0 |= QUEUED; - } - - fn is_closed(self) -> bool { - self.0 & CLOSED == CLOSED - } - - fn set_closed(&mut self) { - self.0 |= CLOSED; - } - - fn unset_queued(&mut self) { - assert!(self.is_queued()); - self.0 -= QUEUED; - } - - fn is_dropped(self) -> bool { - self.0 & DROPPED == DROPPED - } -} diff --git a/tokio/src/sync/tests/loom_mpsc.rs b/tokio/src/sync/tests/loom_mpsc.rs index 6a1a6abedda..e8db2dea4ca 100644 --- a/tokio/src/sync/tests/loom_mpsc.rs +++ b/tokio/src/sync/tests/loom_mpsc.rs @@ -7,17 +7,17 @@ use loom::thread; #[test] fn closing_tx() { loom::model(|| { - let (mut tx, mut rx) = mpsc::channel(16); + let (tx, mut rx) = mpsc::channel(16); thread::spawn(move || { tx.try_send(()).unwrap(); drop(tx); }); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_some()); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -32,10 +32,10 @@ fn closing_unbounded_tx() { drop(tx); }); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_some()); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -53,7 +53,7 @@ fn dropping_tx() { } drop(tx); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -71,7 +71,7 @@ fn dropping_unbounded_tx() { } drop(tx); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } diff --git a/tokio/src/sync/tests/loom_semaphore_ll.rs b/tokio/src/sync/tests/loom_semaphore_ll.rs deleted file mode 100644 index b5e5efba82c..00000000000 --- a/tokio/src/sync/tests/loom_semaphore_ll.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::sync::semaphore_ll::*; - -use futures::future::poll_fn; -use loom::future::block_on; -use loom::thread; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::Arc; -use std::task::Poll::Ready; -use std::task::{Context, Poll}; - -#[test] -fn basic_usage() { - const NUM: usize = 2; - - struct Actor { - waiter: Permit, - shared: Arc, - } - - struct Shared { - semaphore: Semaphore, - active: AtomicUsize, - } - - impl Future for Actor { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let me = &mut *self; - - ready!(me.waiter.poll_acquire(cx, 1, &me.shared.semaphore)).unwrap(); - - let actual = me.shared.active.fetch_add(1, SeqCst); - assert!(actual <= NUM - 1); - - let actual = me.shared.active.fetch_sub(1, SeqCst); - assert!(actual <= NUM); - - me.waiter.release(1, &me.shared.semaphore); - - Ready(()) - } - } - - loom::model(|| { - let shared = Arc::new(Shared { - semaphore: Semaphore::new(NUM), - active: AtomicUsize::new(0), - }); - - for _ in 0..NUM { - let shared = shared.clone(); - - thread::spawn(move || { - block_on(Actor { - waiter: Permit::new(), - shared, - }); - }); - } - - block_on(Actor { - waiter: Permit::new(), - shared, - }); - }); -} - -#[test] -fn release() { - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - { - let semaphore = semaphore.clone(); - thread::spawn(move || { - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); - - permit.release(1, &semaphore); - }); - } - - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); - - permit.release(1, &semaphore); - }); -} - -#[test] -fn basic_closing() { - const NUM: usize = 2; - - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - for _ in 0..NUM { - let semaphore = semaphore.clone(); - - thread::spawn(move || { - let mut permit = Permit::new(); - - for _ in 0..2 { - block_on(poll_fn(|cx| { - permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) - }))?; - - permit.release(1, &semaphore); - } - - Ok::<(), ()>(()) - }); - } - - semaphore.close(); - }); -} - -#[test] -fn concurrent_close() { - const NUM: usize = 3; - - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - for _ in 0..NUM { - let semaphore = semaphore.clone(); - - thread::spawn(move || { - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| { - permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) - }))?; - - permit.release(1, &semaphore); - - semaphore.close(); - - Ok::<(), ()>(()) - }); - } - }); -} - -#[test] -fn batch() { - let mut b = loom::model::Builder::new(); - b.preemption_bound = Some(1); - - b.check(|| { - let semaphore = Arc::new(Semaphore::new(10)); - let active = Arc::new(AtomicUsize::new(0)); - let mut ths = vec![]; - - for _ in 0..2 { - let semaphore = semaphore.clone(); - let active = active.clone(); - - ths.push(thread::spawn(move || { - let mut permit = Permit::new(); - - for n in &[4, 10, 8] { - block_on(poll_fn(|cx| permit.poll_acquire(cx, *n, &semaphore))).unwrap(); - - active.fetch_add(*n as usize, SeqCst); - - let num_active = active.load(SeqCst); - assert!(num_active <= 10); - - thread::yield_now(); - - active.fetch_sub(*n as usize, SeqCst); - - permit.release(*n, &semaphore); - } - })); - } - - for th in ths.into_iter() { - th.join().unwrap(); - } - - assert_eq!(10, semaphore.available_permits()); - }); -} diff --git a/tokio/src/sync/tests/mod.rs b/tokio/src/sync/tests/mod.rs index c637cb6b816..a78be6f3e15 100644 --- a/tokio/src/sync/tests/mod.rs +++ b/tokio/src/sync/tests/mod.rs @@ -1,6 +1,5 @@ cfg_not_loom! { mod atomic_waker; - mod semaphore_ll; mod semaphore_batch; } @@ -12,6 +11,5 @@ cfg_loom! { mod loom_notify; mod loom_oneshot; mod loom_semaphore_batch; - mod loom_semaphore_ll; mod loom_watch; } diff --git a/tokio/src/sync/tests/semaphore_ll.rs b/tokio/src/sync/tests/semaphore_ll.rs deleted file mode 100644 index bfb075780bb..00000000000 --- a/tokio/src/sync/tests/semaphore_ll.rs +++ /dev/null @@ -1,470 +0,0 @@ -use crate::sync::semaphore_ll::{Permit, Semaphore}; -use tokio_test::*; - -#[test] -fn poll_acquire_one_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - assert!(!permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); -} - -#[test] -fn poll_acquire_many_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - assert!(!permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling for a larger number of permits acquires more - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 8, &s))); - assert_eq!(s.available_permits(), 92); - assert!(permit.is_acquired()); -} - -#[test] -fn try_acquire_one_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = Permit::new(); - assert!(!permit.is_acquired()); - - assert_ok!(permit.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ok!(permit.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); -} - -#[test] -fn try_acquire_many_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = Permit::new(); - assert!(!permit.is_acquired()); - - assert_ok!(permit.try_acquire(5, &s)); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ok!(permit.try_acquire(5, &s)); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); -} - -#[test] -fn poll_acquire_one_unavailable() { - let s = Semaphore::new(1); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - - // Acquire the first permit - assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 0); - - permit_2.enter(|cx, mut p| { - // Try to acquire the second permit - assert_pending!(p.poll_acquire(cx, 1, &s)); - }); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 0); - assert!(permit_2.is_woken()); - assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn forget_acquired() { - let s = Semaphore::new(1); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); - - permit.forget(1); - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn forget_waiting() { - let s = Semaphore::new(0); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); - - permit.forget(1); - - s.add_permits(1); - - assert!(!permit.is_woken()); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn poll_acquire_many_unavailable() { - let s = Semaphore::new(5); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - let mut permit_3 = task::spawn(Permit::new()); - - // Acquire the first permit - assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 4); - - permit_2.enter(|cx, mut p| { - // Try to acquire the second permit - assert_pending!(p.poll_acquire(cx, 5, &s)); - }); - - assert_eq!(s.available_permits(), 0); - - permit_3.enter(|cx, mut p| { - // Try to acquire the third permit - assert_pending!(p.poll_acquire(cx, 3, &s)); - }); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 0); - assert!(permit_2.is_woken()); - assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - - assert!(!permit_3.is_woken()); - assert_eq!(s.available_permits(), 0); - - permit_2.release(1, &s); - assert!(!permit_3.is_woken()); - assert_eq!(s.available_permits(), 0); - - permit_2.release(2, &s); - assert!(permit_3.is_woken()); - - assert_ready_ok!(permit_3.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn try_acquire_one_unavailable() { - let s = Semaphore::new(1); - - let mut permit_1 = Permit::new(); - let mut permit_2 = Permit::new(); - - // Acquire the first permit - assert_ok!(permit_1.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 0); - - assert_err!(permit_2.try_acquire(1, &s)); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 1); - assert_ok!(permit_2.try_acquire(1, &s)); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn try_acquire_many_unavailable() { - let s = Semaphore::new(5); - - let mut permit_1 = Permit::new(); - let mut permit_2 = Permit::new(); - - // Acquire the first permit - assert_ok!(permit_1.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 4); - - assert_err!(permit_2.try_acquire(5, &s)); - - permit_1.release(1, &s); - assert_eq!(s.available_permits(), 5); - - assert_ok!(permit_2.try_acquire(5, &s)); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 2); -} - -#[test] -fn poll_acquire_one_zero_permits() { - let s = Semaphore::new(0); - assert_eq!(s.available_permits(), 0); - - let mut permit = task::spawn(Permit::new()); - - // Try to acquire the permit - permit.enter(|cx, mut p| { - assert_pending!(p.poll_acquire(cx, 1, &s)); - }); - - s.add_permits(1); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -#[should_panic] -fn validates_max_permits() { - use std::usize; - Semaphore::new((usize::MAX >> 2) + 1); -} - -#[test] -fn close_semaphore_prevents_acquire() { - let s = Semaphore::new(5); - s.close(); - - assert_eq!(5, s.available_permits()); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - - assert_ready_err!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(5, s.available_permits()); - - assert_ready_err!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(5, s.available_permits()); -} - -#[test] -fn close_semaphore_notifies_permit1() { - let s = Semaphore::new(0); - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.close(); - - assert!(permit.is_woken()); - assert_ready_err!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -fn close_semaphore_notifies_permit2() { - let s = Semaphore::new(2); - - let mut permit1 = task::spawn(Permit::new()); - let mut permit2 = task::spawn(Permit::new()); - let mut permit3 = task::spawn(Permit::new()); - let mut permit4 = task::spawn(Permit::new()); - - // Acquire a couple of permits - assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_pending!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_pending!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.close(); - - assert!(permit3.is_woken()); - assert!(permit4.is_woken()); - - assert_ready_err!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_ready_err!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(0, s.available_permits()); - - permit1.release(1, &s); - - assert_eq!(1, s.available_permits()); - - assert_ready_err!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - permit2.release(1, &s); - - assert_eq!(2, s.available_permits()); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_before_assigned() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); - - s.add_permits(1); - assert!(!permit.is_woken()); - - s.add_permits(1); - assert!(permit.is_woken()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn try_acquire_additional_permits_while_waiting_before_assigned() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - assert_err!(permit.enter(|_, mut p| p.try_acquire(3, &s))); - - s.add_permits(1); - assert!(permit.is_woken()); - - assert_ok!(permit.enter(|_, mut p| p.try_acquire(2, &s))); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_after_assigned_success() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - s.add_permits(2); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_after_assigned_requeue() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - s.add_permits(2); - - assert!(permit.is_woken()); - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); - - s.add_permits(1); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); -} - -#[test] -fn poll_acquire_fewer_permits_while_waiting() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(s.available_permits(), 0); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn poll_acquire_fewer_permits_after_assigned() { - let s = Semaphore::new(1); - - let mut permit1 = task::spawn(Permit::new()); - let mut permit2 = task::spawn(Permit::new()); - - assert_pending!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 0); - - assert_pending!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.add_permits(4); - assert!(permit1.is_woken()); - assert!(!permit2.is_woken()); - - assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); - - assert!(permit2.is_woken()); - assert_eq!(s.available_permits(), 1); - - assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -fn forget_partial_1() { - let s = Semaphore::new(0); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - s.add_permits(1); - - assert_eq!(0, s.available_permits()); - - permit.release(1, &s); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn forget_partial_2() { - let s = Semaphore::new(0); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - s.add_permits(1); - - assert_eq!(0, s.available_permits()); - - permit.release(1, &s); - - s.add_permits(1); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(s.available_permits(), 0); -} diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index d493efe4514..5073855e8a8 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -126,7 +126,6 @@ impl LinkedList { } /// Returns whether the linked list doesn not contain any node - #[cfg_attr(any(feature = "udp", feature = "uds"), allow(unused))] pub(crate) fn is_empty(&self) -> bool { if self.head.is_some() { return false; @@ -182,20 +181,17 @@ impl fmt::Debug for LinkedList { } } -impl Default for LinkedList { - fn default() -> Self { - Self::new() +#[cfg(any(feature = "sync", feature = "signal", feature = "process"))] +impl LinkedList { + pub(crate) fn last(&self) -> Option<&L::Target> { + let tail = self.tail.as_ref()?; + unsafe { Some(&*tail.as_ptr()) } } } -cfg_sync! { - impl LinkedList { - pub(crate) fn last(&self) -> Option<&L::Target> { - let tail = self.tail.as_ref()?; - unsafe { - Some(&*tail.as_ptr()) - } - } +impl Default for LinkedList { + fn default() -> Self { + Self::new() } } diff --git a/tokio/tests/rt_threaded.rs b/tokio/tests/rt_threaded.rs index a67c090ebf4..2c7cfb80c1b 100644 --- a/tokio/tests/rt_threaded.rs +++ b/tokio/tests/rt_threaded.rs @@ -70,7 +70,7 @@ fn many_multishot_futures() { let (start_tx, mut chain_rx) = tokio::sync::mpsc::channel(10); for _ in 0..CHAIN { - let (mut next_tx, next_rx) = tokio::sync::mpsc::channel(10); + let (next_tx, next_rx) = tokio::sync::mpsc::channel(10); // Forward all the messages rt.spawn(async move { @@ -83,8 +83,8 @@ fn many_multishot_futures() { } // This final task cycles if needed - let (mut final_tx, final_rx) = tokio::sync::mpsc::channel(10); - let mut cycle_tx = start_tx.clone(); + let (final_tx, final_rx) = tokio::sync::mpsc::channel(10); + let cycle_tx = start_tx.clone(); let mut rem = CYCLES; rt.spawn(async move { @@ -107,7 +107,7 @@ fn many_multishot_futures() { { rt.block_on(async move { - for mut start_tx in start_txs { + for start_tx in start_txs { start_tx.send("ping").await.unwrap(); } @@ -340,7 +340,7 @@ fn coop_and_block_in_place() { .unwrap(); rt.block_on(async move { - let (mut tx, mut rx) = tokio::sync::mpsc::channel(1024); + let (tx, mut rx) = tokio::sync::mpsc::channel(1024); // Fill the channel for _ in 0..1024 { diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index 919bddbfb18..adefcb12cb8 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -17,74 +17,72 @@ trait AssertSend: Send {} impl AssertSend for mpsc::Sender {} impl AssertSend for mpsc::Receiver {} -#[test] -fn send_recv_with_buffer() { - let (tx, rx) = mpsc::channel::(16); - let mut tx = task::spawn(tx); - let mut rx = task::spawn(rx); +#[tokio::test] +async fn send_recv_with_buffer() { + let (tx, mut rx) = mpsc::channel::(16); // Using poll_ready / try_send - assert_ready_ok!(tx.enter(|cx, mut tx| tx.poll_ready(cx))); - tx.try_send(1).unwrap(); + // let permit assert_ready_ok!(tx.reserve()); + let permit = tx.reserve().await.unwrap(); + permit.send(1); // Without poll_ready tx.try_send(2).unwrap(); drop(tx); - let val = assert_ready!(rx.enter(|cx, mut rx| rx.poll_recv(cx))); + let val = rx.recv().await; assert_eq!(val, Some(1)); - let val = assert_ready!(rx.enter(|cx, mut rx| rx.poll_recv(cx))); + let val = rx.recv().await; assert_eq!(val, Some(2)); - let val = assert_ready!(rx.enter(|cx, mut rx| rx.poll_recv(cx))); + let val = rx.recv().await; assert!(val.is_none()); } -#[test] -fn disarm() { - let (tx, rx) = mpsc::channel::(2); - let mut tx1 = task::spawn(tx.clone()); - let mut tx2 = task::spawn(tx.clone()); - let mut tx3 = task::spawn(tx.clone()); - let mut tx4 = task::spawn(tx); - let mut rx = task::spawn(rx); +#[tokio::test] +async fn reserve_disarm() { + let (tx, mut rx) = mpsc::channel::(2); + let tx1 = tx.clone(); + let tx2 = tx.clone(); + let tx3 = tx.clone(); + let tx4 = tx; // We should be able to `poll_ready` two handles without problem - assert_ready_ok!(tx1.enter(|cx, mut tx| tx.poll_ready(cx))); - assert_ready_ok!(tx2.enter(|cx, mut tx| tx.poll_ready(cx))); + let permit1 = assert_ok!(tx1.reserve().await); + let permit2 = assert_ok!(tx2.reserve().await); // But a third should not be ready - assert_pending!(tx3.enter(|cx, mut tx| tx.poll_ready(cx))); + let mut r3 = task::spawn(tx3.reserve()); + assert_pending!(r3.poll()); + + let mut r4 = task::spawn(tx4.reserve()); + assert_pending!(r4.poll()); // Using one of the reserved slots should allow a new handle to become ready - tx1.try_send(1).unwrap(); + permit1.send(1); + // We also need to receive for the slot to be free - let _ = assert_ready!(rx.enter(|cx, mut rx| rx.poll_recv(cx))).unwrap(); + assert!(!r3.is_woken()); + rx.recv().await.unwrap(); // Now there's a free slot! - assert_ready_ok!(tx3.enter(|cx, mut tx| tx.poll_ready(cx))); - assert_pending!(tx4.enter(|cx, mut tx| tx.poll_ready(cx))); + assert!(r3.is_woken()); + assert!(!r4.is_woken()); - // Dropping a ready handle should also open up a slot - drop(tx2); - assert_ready_ok!(tx4.enter(|cx, mut tx| tx.poll_ready(cx))); - assert_pending!(tx1.enter(|cx, mut tx| tx.poll_ready(cx))); - - // Explicitly disarming a handle should also open a slot - assert!(tx3.disarm()); - assert_ready_ok!(tx1.enter(|cx, mut tx| tx.poll_ready(cx))); + // Dropping a permit should also open up a slot + drop(permit2); + assert!(r4.is_woken()); - // Disarming a non-armed sender does not free up a slot - assert!(!tx3.disarm()); - assert_pending!(tx3.enter(|cx, mut tx| tx.poll_ready(cx))); + let mut r1 = task::spawn(tx1.reserve()); + assert_pending!(r1.poll()); } #[tokio::test] async fn send_recv_stream_with_buffer() { use tokio::stream::StreamExt; - let (mut tx, mut rx) = mpsc::channel::(16); + let (tx, mut rx) = mpsc::channel::(16); tokio::spawn(async move { assert_ok!(tx.send(1).await); @@ -98,7 +96,7 @@ async fn send_recv_stream_with_buffer() { #[tokio::test] async fn async_send_recv_with_buffer() { - let (mut tx, mut rx) = mpsc::channel(16); + let (tx, mut rx) = mpsc::channel(16); tokio::spawn(async move { assert_ok!(tx.send(1).await); @@ -110,37 +108,36 @@ async fn async_send_recv_with_buffer() { assert_eq!(None, rx.recv().await); } -#[test] -fn start_send_past_cap() { +#[tokio::test] +async fn start_send_past_cap() { + use std::future::Future; + let mut t1 = task::spawn(()); - let mut t2 = task::spawn(()); - let mut t3 = task::spawn(()); - let (mut tx1, mut rx) = mpsc::channel(1); - let mut tx2 = tx1.clone(); + let (tx1, mut rx) = mpsc::channel(1); + let tx2 = tx1.clone(); assert_ok!(tx1.try_send(())); - t1.enter(|cx, _| { - assert_pending!(tx1.poll_ready(cx)); - }); + let mut r1 = Box::pin(tx1.reserve()); + t1.enter(|cx, _| assert_pending!(r1.as_mut().poll(cx))); - t2.enter(|cx, _| { - assert_pending!(tx2.poll_ready(cx)); - }); + { + let mut r2 = task::spawn(tx2.reserve()); + assert_pending!(r2.poll()); - drop(tx1); + drop(r1); - let val = t3.enter(|cx, _| assert_ready!(rx.poll_recv(cx))); - assert!(val.is_some()); + assert!(rx.recv().await.is_some()); - assert!(t2.is_woken()); - assert!(!t1.is_woken()); + assert!(r2.is_woken()); + assert!(!t1.is_woken()); + } + drop(tx1); drop(tx2); - let val = t3.enter(|cx, _| assert_ready!(rx.poll_recv(cx))); - assert!(val.is_none()); + assert!(rx.recv().await.is_none()); } #[test] @@ -149,26 +146,20 @@ fn buffer_gteq_one() { mpsc::channel::(0); } -#[test] -fn send_recv_unbounded() { - let mut t1 = task::spawn(()); - +#[tokio::test] +async fn send_recv_unbounded() { let (tx, mut rx) = mpsc::unbounded_channel::(); // Using `try_send` assert_ok!(tx.send(1)); assert_ok!(tx.send(2)); - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert_eq!(val, Some(1)); - - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert_eq!(val, Some(2)); + assert_eq!(rx.recv().await, Some(1)); + assert_eq!(rx.recv().await, Some(2)); drop(tx); - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert!(val.is_none()); + assert!(rx.recv().await.is_none()); } #[tokio::test] @@ -201,11 +192,10 @@ async fn send_recv_stream_unbounded() { assert_eq!(None, rx.next().await); } -#[test] -fn no_t_bounds_buffer() { +#[tokio::test] +async fn no_t_bounds_buffer() { struct NoImpls; - let mut t1 = task::spawn(()); let (tx, mut rx) = mpsc::channel(100); // sender should be Debug even though T isn't Debug @@ -215,15 +205,13 @@ fn no_t_bounds_buffer() { // and sender should be Clone even though T isn't Clone assert!(tx.clone().try_send(NoImpls).is_ok()); - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert!(val.is_some()); + assert!(rx.recv().await.is_some()); } -#[test] -fn no_t_bounds_unbounded() { +#[tokio::test] +async fn no_t_bounds_unbounded() { struct NoImpls; - let mut t1 = task::spawn(()); let (tx, mut rx) = mpsc::unbounded_channel(); // sender should be Debug even though T isn't Debug @@ -233,133 +221,87 @@ fn no_t_bounds_unbounded() { // and sender should be Clone even though T isn't Clone assert!(tx.clone().send(NoImpls).is_ok()); - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert!(val.is_some()); + assert!(rx.recv().await.is_some()); } -#[test] -fn send_recv_buffer_limited() { - let mut t1 = task::spawn(()); - let mut t2 = task::spawn(()); - - let (mut tx, mut rx) = mpsc::channel::(1); - - // Run on a task context - t1.enter(|cx, _| { - assert_ready_ok!(tx.poll_ready(cx)); - - // Send first message - assert_ok!(tx.try_send(1)); +#[tokio::test] +async fn send_recv_buffer_limited() { + let (tx, mut rx) = mpsc::channel::(1); - // Not ready - assert_pending!(tx.poll_ready(cx)); + // Reserve capacity + let p1 = assert_ok!(tx.reserve().await); - // Send second message - assert_err!(tx.try_send(1337)); - }); + // Send first message + p1.send(1); - t2.enter(|cx, _| { - // Take the value - let val = assert_ready!(rx.poll_recv(cx)); - assert_eq!(Some(1), val); - }); + // Not ready + let mut p2 = task::spawn(tx.reserve()); + assert_pending!(p2.poll()); - assert!(t1.is_woken()); + // Take the value + assert!(rx.recv().await.is_some()); - t1.enter(|cx, _| { - assert_ready_ok!(tx.poll_ready(cx)); + // Notified + assert!(p2.is_woken()); - assert_ok!(tx.try_send(2)); + // Trying to send fails + assert_err!(tx.try_send(1337)); - // Not ready - assert_pending!(tx.poll_ready(cx)); - }); + // Send second + let permit = assert_ready_ok!(p2.poll()); + permit.send(2); - t2.enter(|cx, _| { - // Take the value - let val = assert_ready!(rx.poll_recv(cx)); - assert_eq!(Some(2), val); - }); - - t1.enter(|cx, _| { - assert_ready_ok!(tx.poll_ready(cx)); - }); + assert!(rx.recv().await.is_some()); } -#[test] -fn recv_close_gets_none_idle() { - let mut t1 = task::spawn(()); - - let (mut tx, mut rx) = mpsc::channel::(10); +#[tokio::test] +async fn recv_close_gets_none_idle() { + let (tx, mut rx) = mpsc::channel::(10); rx.close(); - t1.enter(|cx, _| { - let val = assert_ready!(rx.poll_recv(cx)); - assert!(val.is_none()); - assert_ready_err!(tx.poll_ready(cx)); - }); -} - -#[test] -fn recv_close_gets_none_reserved() { - let mut t1 = task::spawn(()); - let mut t2 = task::spawn(()); - let mut t3 = task::spawn(()); + assert!(rx.recv().await.is_none()); - let (mut tx1, mut rx) = mpsc::channel::(1); - let mut tx2 = tx1.clone(); + assert_err!(tx.send(1).await); +} - assert_ready_ok!(t1.enter(|cx, _| tx1.poll_ready(cx))); +#[tokio::test] +async fn recv_close_gets_none_reserved() { + let (tx1, mut rx) = mpsc::channel::(1); + let tx2 = tx1.clone(); - t2.enter(|cx, _| { - assert_pending!(tx2.poll_ready(cx)); - }); + let permit1 = assert_ok!(tx1.reserve().await); + let mut permit2 = task::spawn(tx2.reserve()); + assert_pending!(permit2.poll()); rx.close(); - assert!(t2.is_woken()); - - t2.enter(|cx, _| { - assert_ready_err!(tx2.poll_ready(cx)); - }); - - t3.enter(|cx, _| assert_pending!(rx.poll_recv(cx))); + assert!(permit2.is_woken()); + assert_ready_err!(permit2.poll()); - assert!(!t1.is_woken()); - assert!(!t2.is_woken()); - - assert_ok!(tx1.try_send(123)); + { + let mut recv = task::spawn(rx.recv()); + assert_pending!(recv.poll()); - assert!(t3.is_woken()); + permit1.send(123); + assert!(recv.is_woken()); - t3.enter(|cx, _| { - let v = assert_ready!(rx.poll_recv(cx)); + let v = assert_ready!(recv.poll()); assert_eq!(v, Some(123)); + } - let v = assert_ready!(rx.poll_recv(cx)); - assert!(v.is_none()); - }); + assert!(rx.recv().await.is_none()); } -#[test] -fn tx_close_gets_none() { - let mut t1 = task::spawn(()); - +#[tokio::test] +async fn tx_close_gets_none() { let (_, mut rx) = mpsc::channel::(10); - - // Run on a task context - t1.enter(|cx, _| { - let v = assert_ready!(rx.poll_recv(cx)); - assert!(v.is_none()); - }); + assert!(rx.recv().await.is_none()); } -#[test] -fn try_send_fail() { - let mut t1 = task::spawn(()); - - let (mut tx, mut rx) = mpsc::channel(1); +#[tokio::test] +async fn try_send_fail() { + let (tx, mut rx) = mpsc::channel(1); tx.try_send("hello").unwrap(); @@ -369,60 +311,48 @@ fn try_send_fail() { _ => panic!(), } - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert_eq!(val, Some("hello")); + assert_eq!(rx.recv().await, Some("hello")); assert_ok!(tx.try_send("goodbye")); drop(tx); - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert_eq!(val, Some("goodbye")); - - let val = assert_ready!(t1.enter(|cx, _| rx.poll_recv(cx))); - assert!(val.is_none()); + assert_eq!(rx.recv().await, Some("goodbye")); + assert!(rx.recv().await.is_none()); } -#[test] -fn drop_tx_with_permit_releases_permit() { - let mut t1 = task::spawn(()); - let mut t2 = task::spawn(()); - +#[tokio::test] +async fn drop_permit_releases_permit() { // poll_ready reserves capacity, ensure that the capacity is released if tx // is dropped w/o sending a value. - let (mut tx1, _rx) = mpsc::channel::(1); - let mut tx2 = tx1.clone(); - - assert_ready_ok!(t1.enter(|cx, _| tx1.poll_ready(cx))); + let (tx1, _rx) = mpsc::channel::(1); + let tx2 = tx1.clone(); - t2.enter(|cx, _| { - assert_pending!(tx2.poll_ready(cx)); - }); + let permit = assert_ok!(tx1.reserve().await); - drop(tx1); + let mut reserve2 = task::spawn(tx2.reserve()); + assert_pending!(reserve2.poll()); - assert!(t2.is_woken()); + drop(permit); - assert_ready_ok!(t2.enter(|cx, _| tx2.poll_ready(cx))); + assert!(reserve2.is_woken()); + assert_ready_ok!(reserve2.poll()); } -#[test] -fn dropping_rx_closes_channel() { - let mut t1 = task::spawn(()); - - let (mut tx, rx) = mpsc::channel(100); +#[tokio::test] +async fn dropping_rx_closes_channel() { + let (tx, rx) = mpsc::channel(100); let msg = Arc::new(()); assert_ok!(tx.try_send(msg.clone())); drop(rx); - assert_ready_err!(t1.enter(|cx, _| tx.poll_ready(cx))); - + assert_err!(tx.reserve().await); assert_eq!(1, Arc::strong_count(&msg)); } #[test] fn dropping_rx_closes_channel_for_try() { - let (mut tx, rx) = mpsc::channel(100); + let (tx, rx) = mpsc::channel(100); let msg = Arc::new(()); tx.try_send(msg.clone()).unwrap(); @@ -444,7 +374,7 @@ fn dropping_rx_closes_channel_for_try() { fn unconsumed_messages_are_dropped() { let msg = Arc::new(()); - let (mut tx, rx) = mpsc::channel(100); + let (tx, rx) = mpsc::channel(100); tx.try_send(msg.clone()).unwrap(); @@ -457,7 +387,7 @@ fn unconsumed_messages_are_dropped() { #[test] fn try_recv() { - let (mut tx, mut rx) = mpsc::channel(1); + let (tx, mut rx) = mpsc::channel(1); match rx.try_recv() { Err(TryRecvError::Empty) => {} _ => panic!(), @@ -495,7 +425,7 @@ fn try_recv_unbounded() { #[test] fn blocking_recv() { - let (mut tx, mut rx) = mpsc::channel::(1); + let (tx, mut rx) = mpsc::channel::(1); let sync_code = thread::spawn(move || { assert_eq!(Some(10), rx.blocking_recv()); @@ -516,7 +446,7 @@ async fn blocking_recv_async() { #[test] fn blocking_send() { - let (mut tx, mut rx) = mpsc::channel::(1); + let (tx, mut rx) = mpsc::channel::(1); let sync_code = thread::spawn(move || { tx.blocking_send(10).unwrap(); @@ -531,28 +461,25 @@ fn blocking_send() { #[tokio::test] #[should_panic] async fn blocking_send_async() { - let (mut tx, _rx) = mpsc::channel::<()>(1); + let (tx, _rx) = mpsc::channel::<()>(1); let _ = tx.blocking_send(()); } -#[test] -fn ready_close_cancel_bounded() { - use futures::future::poll_fn; - - let (mut tx, mut rx) = mpsc::channel::<()>(100); +#[tokio::test] +async fn ready_close_cancel_bounded() { + let (tx, mut rx) = mpsc::channel::<()>(100); let _tx2 = tx.clone(); - { - let mut ready = task::spawn(async { poll_fn(|cx| tx.poll_ready(cx)).await }); - assert_ready_ok!(ready.poll()); - } + let permit = assert_ok!(tx.reserve().await); rx.close(); - let mut recv = task::spawn(async { rx.recv().await }); + let mut recv = task::spawn(rx.recv()); assert_pending!(recv.poll()); - drop(tx); + drop(permit); assert!(recv.is_woken()); + let val = assert_ready!(recv.poll()); + assert!(val.is_none()); }