diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index e86b29e699e..8a9961470c4 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -154,56 +154,75 @@ where /// Read the task output into `dst`. pub(super) fn try_read_output(self, dst: &mut Poll>, waker: &Waker) { + if self.try_poll_completion(waker) { + *dst = Poll::Ready(self.core().take_output()); + } + } + + pub(super) fn try_read_completion(self, dst: &mut Poll>, waker: &Waker) { + if self.try_poll_completion(waker) { + // TODO: might be room for more optimization to avoid reading out + // the output (avoiding the map). + *dst = Poll::Ready(self.core().take_output().map(|_| ())); + } + } + + /// Try to poll the task associated with the harness to completion, or + /// install the associated waker to be notified when the task is complete. + /// + /// Returns a boolean indicating if the task is complete or not. + fn try_poll_completion(&self, waker: &Waker) -> bool { // Load a snapshot of the current task state let snapshot = self.header().state.load(); debug_assert!(snapshot.is_join_interested()); - if !snapshot.is_complete() { - // The waker must be stored in the task struct. - let res = if snapshot.has_join_waker() { - // There already is a waker stored in the struct. If it matches - // the provided waker, then there is no further work to do. - // Otherwise, the waker must be swapped. - let will_wake = unsafe { - // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE` - // may mutate the `waker` field. - self.trailer() - .waker - .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) - }; - - if will_wake { - // The task is not complete **and** the waker is up to date, - // there is nothing further that needs to be done. - return; - } + if snapshot.is_complete() { + return true; + } - // Unset the `JOIN_WAKER` to gain mutable access to the `waker` - // field then update the field with the new join worker. - // - // This requires two atomic operations, unsetting the bit and - // then resetting it. If the task transitions to complete - // concurrently to either one of those operations, then setting - // the join waker fails and we proceed to reading the task - // output. - self.header() - .state - .unset_waker() - .and_then(|snapshot| self.set_join_waker(waker.clone(), snapshot)) - } else { - self.set_join_waker(waker.clone(), snapshot) + // The waker must be stored in the task struct. + let res = if snapshot.has_join_waker() { + // There already is a waker stored in the struct. If it matches + // the provided waker, then there is no further work to do. + // Otherwise, the waker must be swapped. + let will_wake = unsafe { + // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE` + // may mutate the `waker` field. + self.trailer() + .waker + .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) }; - match res { - Ok(_) => return, - Err(snapshot) => { - assert!(snapshot.is_complete()); - } + if will_wake { + // The task is not complete **and** the waker is up to date, + // there is nothing further that needs to be done. + return false; } - } - *dst = Poll::Ready(self.core().take_output()); + // Unset the `JOIN_WAKER` to gain mutable access to the `waker` + // field then update the field with the new join worker. + // + // This requires two atomic operations, unsetting the bit and + // then resetting it. If the task transitions to complete + // concurrently to either one of those operations, then setting + // the join waker fails and we proceed to reading the task + // output. + self.header() + .state + .unset_waker() + .and_then(|snapshot| self.set_join_waker(waker.clone(), snapshot)) + } else { + self.set_join_waker(waker.clone(), snapshot) + }; + + match res { + Ok(_) => false, + Err(snapshot) => { + assert!(snapshot.is_complete()); + true + } + } } fn set_join_waker(&self, waker: Waker, snapshot: Snapshot) -> Result { diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index fdcc346e5c1..cd8f19505de 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -76,9 +76,26 @@ doc_rt_core! { /// [`task::spawn_blocking`]: crate::task::spawn_blocking /// [`std::thread::JoinHandle`]: std::thread::JoinHandle pub struct JoinHandle { - raw: Option, + raw: RawJoinHandle, _p: PhantomData, } + + /// A type-erased variant of [`JoinHandle`], these are created by using + /// [`JoinHandle::into_raw_handle`]. + /// + /// Raw join handles erase the type information associated with a + /// [`JoinHandle`], allowing them to easily be stored in containers for + /// future cancellation. + /// + /// They behave exactly the same as regular [`JoinHandle`], except that + /// instead of resolving to `Result`, they resolve to + /// `Result<(), JoinError>`. + /// + /// [`JoinHandle`]: JoinHandle + /// [`JoinHandle::into_raw_handle`]: JoinHandle::into_raw_handle + pub struct RawJoinHandle { + raw: Option, + } } unsafe impl Send for JoinHandle {} @@ -87,13 +104,173 @@ unsafe impl Sync for JoinHandle {} impl JoinHandle { pub(super) fn new(raw: RawTask) -> JoinHandle { JoinHandle { - raw: Some(raw), + raw: RawJoinHandle { + raw: Some(raw), + }, _p: PhantomData, } } + + /// Cancel the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will complete with a `Err(JoinError::Cancelled)`. + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// handles.push(tokio::spawn(async { + /// time::delay_for(time::Duration::from_secs(10)).await; + /// true + /// })); + /// + /// handles.push(tokio::spawn(async { + /// time::delay_for(time::Duration::from_secs(10)).await; + /// false + /// })); + /// + /// for handle in &handles { + /// handle.cancel(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + pub fn cancel(&self) { + if let Some(raw) = self.raw.raw { + raw.shutdown(); + } + } + + /// Convert the join handle into a raw join handle. + /// + /// Raw join handles erase the type information associated with a + /// [`JoinHandle`], allowing them to easily be stored in containers for + /// future cancellation. + /// + /// [`JoinHandle`]: JoinHandle + /// + /// # Examples + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// let handle = tokio::spawn(async { + /// time::delay_for(time::Duration::from_secs(10)).await; + /// true + /// }); + /// + /// handles.push(handle.into_raw_handle()); + /// + /// let handle = tokio::spawn(async { + /// time::delay_for(time::Duration::from_secs(10)).await; + /// 1u32 + /// }); + /// + /// handles.push(handle.into_raw_handle()); + /// + /// for handle in &handles { + /// handle.cancel(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + pub fn into_raw_handle(self) -> RawJoinHandle { + self.raw + } +} + +impl RawJoinHandle { + /// Cancel the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will complete with a `Err(JoinError::Cancelled)`. + /// + /// # Examples + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// let handle = tokio::spawn(async { + /// time::delay_for(time::Duration::from_secs(10)).await; + /// true + /// }); + /// + /// handles.push(handle.into_raw_handle()); + /// + /// let handle = tokio::spawn(async { + /// time::delay_for(time::Duration::from_secs(10)).await; + /// 1u32 + /// }); + /// + /// handles.push(handle.into_raw_handle()); + /// + /// for handle in &handles { + /// handle.cancel(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + pub fn cancel(&self) { + if let Some(raw) = self.raw { + raw.shutdown(); + } + } } -impl Unpin for JoinHandle {} +impl Unpin for RawJoinHandle {} + +impl Future for RawJoinHandle { + type Output = super::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut ret = Poll::Pending; + + // Keep track of task budget + ready!(crate::coop::poll_proceed(cx)); + + // Raw should always be set. If it is not, this is due to polling after + // completion + let raw = self + .raw + .as_ref() + .expect("polling after `JoinHandle` already completed"); + + // Try to read the task completion. If the task is not yet complete, the + // waker is stored and is notified once the task does complete. + // + // Safety: + // + // `ret` must be of type `Poll<()>`. + unsafe { + raw.try_read_completion(&mut ret as *mut _ as *mut (), cx.waker()); + } + + ret + } +} impl Future for JoinHandle { type Output = super::Result; @@ -107,6 +284,7 @@ impl Future for JoinHandle { // Raw should always be set. If it is not, this is due to polling after // completion let raw = self + .raw .raw .as_ref() .expect("polling after `JoinHandle` already completed"); @@ -130,7 +308,7 @@ impl Future for JoinHandle { } } -impl Drop for JoinHandle { +impl Drop for RawJoinHandle { fn drop(&mut self) { if let Some(raw) = self.raw.take() { if raw.header().state.drop_join_handle_fast().is_ok() { @@ -150,3 +328,9 @@ where fmt.debug_struct("JoinHandle").finish() } } + +impl fmt::Debug for RawJoinHandle { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("RawJoinHandle").finish() + } +} diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index cae56d037da..decd992e4b6 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -19,6 +19,9 @@ pub(super) struct Vtable { /// Read the task output, if complete pub(super) try_read_output: unsafe fn(NonNull
, *mut (), &Waker), + /// Read the task completion, if complete + pub(super) try_read_completion: unsafe fn(NonNull
, *mut (), &Waker), + /// The join handle has been dropped pub(super) drop_join_handle_slow: unsafe fn(NonNull
), @@ -32,6 +35,7 @@ pub(super) fn vtable() -> &'static Vtable { poll: poll::, dealloc: dealloc::, try_read_output: try_read_output::, + try_read_completion: try_read_completion::, drop_join_handle_slow: drop_join_handle_slow::, shutdown: shutdown::, } @@ -80,6 +84,12 @@ impl RawTask { (vtable.try_read_output)(self.ptr, dst, waker); } + /// Safety: `dst` must be a `*mut Poll>`. + pub(super) unsafe fn try_read_completion(self, dst: *mut (), waker: &Waker) { + let vtable = self.header().vtable; + (vtable.try_read_completion)(self.ptr, dst, waker); + } + pub(super) fn drop_join_handle_slow(self) { let vtable = self.header().vtable; unsafe { (vtable.drop_join_handle_slow)(self.ptr) } @@ -120,6 +130,17 @@ unsafe fn try_read_output( harness.try_read_output(out, waker); } +unsafe fn try_read_completion( + ptr: NonNull
, + dst: *mut (), + waker: &Waker, +) { + let out = &mut *(dst as *mut Poll>); + + let harness = Harness::::from_raw(ptr); + harness.try_read_completion(out, waker); +} + unsafe fn drop_join_handle_slow(ptr: NonNull
) { let harness = Harness::::from_raw(ptr); harness.drop_join_handle_slow()