diff --git a/tokio/src/runtime/task/abort.rs b/tokio/src/runtime/task/abort.rs new file mode 100644 index 00000000000..b56cb5cd81a --- /dev/null +++ b/tokio/src/runtime/task/abort.rs @@ -0,0 +1,69 @@ +use crate::runtime::task::RawTask; +use std::fmt; +use std::panic::{RefUnwindSafe, UnwindSafe}; + +/// An owned permission to abort a spawned task, without awaiting its completion. +/// +/// Unlike a [`JoinHandle`], an `AbortHandle` does *not* represent the +/// permission to await the task's completion, only to terminate it. +/// +/// The task may be aborted by calling the [`AbortHandle::abort`] method. +/// Dropping an `AbortHandle` releases the permission to terminate the task +/// --- it does *not* abort the task. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +/// [`JoinHandle`]: crate::task::JoinHandle +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct AbortHandle { + raw: Option, +} + +impl AbortHandle { + pub(super) fn new(raw: Option) -> Self { + Self { raw } + } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will fail with a [cancelled] `JoinError`. + /// + /// If the task was already cancelled, such as by [`JoinHandle::abort`], + /// this method will do nothing. + /// + /// [cancelled]: method@super::error::JoinError::is_cancelled + // the `AbortHandle` type is only publicly exposed when `tokio_unstable` is + // enabled, but it is still defined for testing purposes. + #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] + pub fn abort(self) { + if let Some(raw) = self.raw { + raw.remote_abort(); + } + } +} + +unsafe impl Send for AbortHandle {} +unsafe impl Sync for AbortHandle {} + +impl UnwindSafe for AbortHandle {} +impl RefUnwindSafe for AbortHandle {} + +impl fmt::Debug for AbortHandle { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("AbortHandle").finish() + } +} + +impl Drop for AbortHandle { + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + raw.drop_abort_handle(); + } + } +} diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index 8beed2eaacb..b7846348e34 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -210,6 +210,16 @@ impl JoinHandle { } } } + + /// Returns a new `AbortHandle` that can be used to remotely abort this task. + #[cfg(any(tokio_unstable, test))] + pub(crate) fn abort_handle(&self) -> super::AbortHandle { + let raw = self.raw.map(|raw| { + raw.ref_inc(); + raw + }); + super::AbortHandle::new(raw) + } } impl Unpin for JoinHandle {} diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 2a492dc985d..c7a85c5a173 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -155,7 +155,14 @@ cfg_rt_multi_thread! { pub(super) use self::inject::Inject; } +#[cfg(all(feature = "rt", any(tokio_unstable, test)))] +mod abort; mod join; + +#[cfg(all(feature = "rt", any(tokio_unstable, test)))] +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::abort::AbortHandle; + #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::join::JoinHandle; diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 2e4420b5c13..95569f947e5 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -27,6 +27,9 @@ pub(super) struct Vtable { /// The join handle has been dropped. pub(super) drop_join_handle_slow: unsafe fn(NonNull
), + /// An abort handle has been dropped. + pub(super) drop_abort_handle: unsafe fn(NonNull
), + /// The task is remotely aborted. pub(super) remote_abort: unsafe fn(NonNull
), @@ -42,6 +45,7 @@ pub(super) fn vtable() -> &'static Vtable { try_read_output: try_read_output::, try_set_join_waker: try_set_join_waker::, drop_join_handle_slow: drop_join_handle_slow::, + drop_abort_handle: drop_abort_handle::, remote_abort: remote_abort::, shutdown: shutdown::, } @@ -104,6 +108,11 @@ impl RawTask { unsafe { (vtable.drop_join_handle_slow)(self.ptr) } } + pub(super) fn drop_abort_handle(self) { + let vtable = self.header().vtable; + unsafe { (vtable.drop_abort_handle)(self.ptr) } + } + pub(super) fn shutdown(self) { let vtable = self.header().vtable; unsafe { (vtable.shutdown)(self.ptr) } @@ -113,6 +122,13 @@ impl RawTask { let vtable = self.header().vtable; unsafe { (vtable.remote_abort)(self.ptr) } } + + /// Increment the task's reference count. + /// + /// Currently, this is used only when creating an `AbortHandle`. + pub(super) fn ref_inc(self) { + self.header().state.ref_inc(); + } } impl Clone for RawTask { @@ -154,6 +170,11 @@ unsafe fn drop_join_handle_slow(ptr: NonNull
) { harness.drop_join_handle_slow() } +unsafe fn drop_abort_handle(ptr: NonNull
) { + let harness = Harness::::from_raw(ptr); + harness.drop_reference(); +} + unsafe fn remote_abort(ptr: NonNull
) { let harness = Harness::::from_raw(ptr); harness.remote_abort() diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index 04e1b56e777..622f5661784 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -78,6 +78,44 @@ fn create_drop2() { handle.assert_dropped(); } +#[test] +fn drop_abort_handle1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + let abort = join.abort_handle(); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_not_dropped(); + drop(abort); + handle.assert_dropped(); +} + +#[test] +fn drop_abort_handle2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + let abort = join.abort_handle(); + drop(notified); + handle.assert_not_dropped(); + drop(abort); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + // Shutting down through Notified works #[test] fn create_shutdown1() { diff --git a/tokio/src/runtime/tests/task_combinations.rs b/tokio/src/runtime/tests/task_combinations.rs index 76ce2330c2c..5c7a0b0109b 100644 --- a/tokio/src/runtime/tests/task_combinations.rs +++ b/tokio/src/runtime/tests/task_combinations.rs @@ -3,6 +3,7 @@ use std::panic; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::runtime::task::AbortHandle; use crate::runtime::Builder; use crate::sync::oneshot; use crate::task::JoinHandle; @@ -56,6 +57,12 @@ enum CombiAbort { AbortedAfterConsumeOutput = 4, } +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbortSource { + JoinHandle, + AbortHandle, +} + #[test] fn test_combinations() { let mut rt = &[ @@ -90,6 +97,13 @@ fn test_combinations() { CombiAbort::AbortedAfterFinish, CombiAbort::AbortedAfterConsumeOutput, ]; + let ah = [ + None, + Some(CombiJoinHandle::DropImmediately), + Some(CombiJoinHandle::DropFirstPoll), + Some(CombiJoinHandle::DropAfterNoConsume), + Some(CombiJoinHandle::DropAfterConsume), + ]; for rt in rt.iter().copied() { for ls in ls.iter().copied() { @@ -98,7 +112,34 @@ fn test_combinations() { for ji in ji.iter().copied() { for jh in jh.iter().copied() { for abort in abort.iter().copied() { - test_combination(rt, ls, task, output, ji, jh, abort); + // abort via join handle --- abort handles + // may be dropped at any point + for ah in ah.iter().copied() { + test_combination( + rt, + ls, + task, + output, + ji, + jh, + ah, + abort, + CombiAbortSource::JoinHandle, + ); + } + // if aborting via AbortHandle, it will + // never be dropped. + test_combination( + rt, + ls, + task, + output, + ji, + jh, + None, + abort, + CombiAbortSource::AbortHandle, + ); } } } @@ -108,6 +149,7 @@ fn test_combinations() { } } +#[allow(clippy::too_many_arguments)] fn test_combination( rt: CombiRuntime, ls: CombiLocalSet, @@ -115,12 +157,24 @@ fn test_combination( output: CombiOutput, ji: CombiJoinInterest, jh: CombiJoinHandle, + ah: Option, abort: CombiAbort, + abort_src: CombiAbortSource, ) { - if (jh as usize) < (abort as usize) { - // drop before abort not possible - return; + match (abort_src, ah) { + (CombiAbortSource::JoinHandle, _) if (jh as usize) < (abort as usize) => { + // join handle dropped prior to abort + return; + } + (CombiAbortSource::AbortHandle, Some(_)) => { + // abort handle dropped, we can't abort through the + // abort handle + return; + } + + _ => {} } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { // this causes double panic return; @@ -130,7 +184,7 @@ fn test_combination( return; } - println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, AbortHandle {:?}, Abort {:?} ({:?})", rt, ls, task, output, ji, jh, ah, abort, abort_src); // A runtime optionally with a LocalSet struct Rt { @@ -282,8 +336,24 @@ fn test_combination( ); } + // If we are either aborting the task via an abort handle, or dropping via + // an abort handle, do that now. + let mut abort_handle = if ah.is_some() || abort_src == CombiAbortSource::AbortHandle { + handle.as_ref().map(JoinHandle::abort_handle) + } else { + None + }; + + let do_abort = |abort_handle: &mut Option, + join_handle: Option<&mut JoinHandle<_>>| { + match abort_src { + CombiAbortSource::AbortHandle => abort_handle.take().unwrap().abort(), + CombiAbortSource::JoinHandle => join_handle.unwrap().abort(), + } + }; + if abort == CombiAbort::AbortedImmediately { - handle.as_mut().unwrap().abort(); + do_abort(&mut abort_handle, handle.as_mut()); aborted = true; } if jh == CombiJoinHandle::DropImmediately { @@ -301,12 +371,15 @@ fn test_combination( } if abort == CombiAbort::AbortedFirstPoll { - handle.as_mut().unwrap().abort(); + do_abort(&mut abort_handle, handle.as_mut()); aborted = true; } if jh == CombiJoinHandle::DropFirstPoll { drop(handle.take().unwrap()); } + if ah == Some(CombiJoinHandle::DropFirstPoll) { + drop(abort_handle.take().unwrap()); + } // Signal the future that it can return now let _ = on_complete.send(()); @@ -318,23 +391,42 @@ fn test_combination( if abort == CombiAbort::AbortedAfterFinish { // Don't set aborted to true here as the task already finished - handle.as_mut().unwrap().abort(); + do_abort(&mut abort_handle, handle.as_mut()); } if jh == CombiJoinHandle::DropAfterNoConsume { - // The runtime will usually have dropped every ref-count at this point, - // in which case dropping the JoinHandle drops the output. - // - // (But it might race and still hold a ref-count) - let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if ah == Some(CombiJoinHandle::DropAfterNoConsume) { drop(handle.take().unwrap()); - })); - if panic.is_err() { - assert!( - (output == CombiOutput::PanicOnDrop) - && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) - && !aborted, - "Dropping JoinHandle shouldn't panic here" - ); + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the AbortHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(abort_handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping AbortHandle shouldn't panic here" + ); + } + } else { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } } } @@ -362,11 +454,15 @@ fn test_combination( _ => unreachable!(), } - let handle = handle.take().unwrap(); + let mut handle = handle.take().unwrap(); if abort == CombiAbort::AbortedAfterConsumeOutput { - handle.abort(); + do_abort(&mut abort_handle, Some(&mut handle)); } drop(handle); + + if ah == Some(CombiJoinHandle::DropAfterConsume) { + drop(abort_handle.take()); + } } // The output should have been dropped now. Check whether the output diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index a2f848391e1..996ae1a9219 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use crate::runtime::Handle; -use crate::task::{JoinError, JoinHandle, LocalSet}; +use crate::task::{AbortHandle, JoinError, JoinHandle, LocalSet}; use crate::util::IdleNotifiedSet; /// A collection of tasks spawned on a Tokio runtime. @@ -73,61 +73,76 @@ impl JoinSet { } impl JoinSet { - /// Spawn the provided task on the `JoinSet`. + /// Spawn the provided task on the `JoinSet`, returning an [`AbortHandle`] + /// that can be used to remotely cancel the task. /// /// # Panics /// /// This method panics if called outside of a Tokio runtime. - pub fn spawn(&mut self, task: F) + /// + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn(&mut self, task: F) -> AbortHandle where F: Future, F: Send + 'static, T: Send, { - self.insert(crate::spawn(task)); + self.insert(crate::spawn(task)) } - /// Spawn the provided task on the provided runtime and store it in this `JoinSet`. - pub fn spawn_on(&mut self, task: F, handle: &Handle) + /// Spawn the provided task on the provided runtime and store it in this + /// `JoinSet` returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. + /// + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn_on(&mut self, task: F, handle: &Handle) -> AbortHandle where F: Future, F: Send + 'static, T: Send, { - self.insert(handle.spawn(task)); + self.insert(handle.spawn(task)) } - /// Spawn the provided task on the current [`LocalSet`] and store it in this `JoinSet`. + /// Spawn the provided task on the current [`LocalSet`] and store it in this + /// `JoinSet`, returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. /// /// # Panics /// /// This method panics if it is called outside of a `LocalSet`. /// /// [`LocalSet`]: crate::task::LocalSet - pub fn spawn_local(&mut self, task: F) + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn_local(&mut self, task: F) -> AbortHandle where F: Future, F: 'static, { - self.insert(crate::task::spawn_local(task)); + self.insert(crate::task::spawn_local(task)) } - /// Spawn the provided task on the provided [`LocalSet`] and store it in this `JoinSet`. + /// Spawn the provided task on the provided [`LocalSet`] and store it in + /// this `JoinSet`, returning an [`AbortHandle`] that can be used to + /// remotely cancel the task. /// /// [`LocalSet`]: crate::task::LocalSet - pub fn spawn_local_on(&mut self, task: F, local_set: &LocalSet) + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn_local_on(&mut self, task: F, local_set: &LocalSet) -> AbortHandle where F: Future, F: 'static, { - self.insert(local_set.spawn_local(task)); + self.insert(local_set.spawn_local(task)) } - fn insert(&mut self, jh: JoinHandle) { + fn insert(&mut self, jh: JoinHandle) -> AbortHandle { + let abort = jh.abort_handle(); let mut entry = self.inner.insert_idle(jh); // Set the waker that is notified when the task completes. entry.with_value_and_context(|jh, ctx| jh.set_join_waker(ctx.waker())); + abort } /// Waits until one of the tasks in the set completes and returns its output. diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index d532155a1fe..bf5530b8339 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -303,6 +303,7 @@ cfg_rt! { cfg_unstable! { mod join_set; pub use join_set::JoinSet; + pub use crate::runtime::task::AbortHandle; } cfg_trace! { diff --git a/tokio/tests/task_join_set.rs b/tokio/tests/task_join_set.rs index 66a2fbb021d..470f861fe9b 100644 --- a/tokio/tests/task_join_set.rs +++ b/tokio/tests/task_join_set.rs @@ -106,6 +106,40 @@ async fn alternating() { } } +#[tokio::test(start_paused = true)] +async fn abort_tasks() { + let mut set = JoinSet::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + let abort = set.spawn(async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + + if i % 2 != 0 { + // abort odd-numbered tasks. + abort.abort(); + } + } + loop { + match set.join_one().await { + Ok(Some(res)) => { + num_completed += 1; + assert_eq!(res % 2, 0); + } + Err(e) => { + assert!(e.is_cancelled()); + num_canceled += 1; + } + Ok(None) => break, + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + #[test] fn runtime_gone() { let mut set = JoinSet::new();