From c4fa5c9c62bf107a65e8129431214252835b2b34 Mon Sep 17 00:00:00 2001 From: AzureMarker Date: Tue, 21 Dec 2021 12:22:13 -0500 Subject: [PATCH] Use a robust abort handling mechanism --- tokio-util/Cargo.toml | 2 +- tokio-util/src/task/spawn_pinned.rs | 83 ++++++++++------------------- 2 files changed, 29 insertions(+), 56 deletions(-) diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index c696238c145..e00cdb8cce6 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -29,7 +29,7 @@ codec = [] time = ["tokio/time","slab"] io = [] io-util = ["io", "tokio/rt", "tokio/io-util"] -rt = ["tokio/rt", "tokio/sync"] +rt = ["tokio/rt", "tokio/sync", "futures-util"] __docs_rs = ["futures-util"] diff --git a/tokio-util/src/task/spawn_pinned.rs b/tokio-util/src/task/spawn_pinned.rs index 0d6a7a06c57..cfe43a034ea 100644 --- a/tokio-util/src/task/spawn_pinned.rs +++ b/tokio-util/src/task/spawn_pinned.rs @@ -1,10 +1,9 @@ +use futures_util::future::{AbortHandle, Abortable}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::future::Future; -use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::runtime::Builder; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot; @@ -96,7 +95,7 @@ impl LocalPool { let (sender, receiver) = oneshot::channel(); let worker = self.find_and_incr_least_burdened_worker(); - let job_guard = JobGuard(Arc::clone(&worker.task_count)); + let job_guard = JobCountGuard(Arc::clone(&worker.task_count)); let worker_spawner = worker.spawner.clone(); // Spawn a future onto the worker's runtime so can immediately return @@ -105,12 +104,19 @@ impl LocalPool { // Move the job guard into the task let _job_guard = job_guard; + // Propagate aborts via Abortable/AbortHandle + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let _abort_guard = AbortGuard(abort_handle); + // Inside the future we can't run spawn_local yet because we're not // in the context of a LocalSet. We need to send create_task to the // LocalSet task for spawning. let spawn_task = Box::new(move || { // Once we're in the LocalSet context we can call spawn_local - let join_handle = spawn_local(async move { create_task().await }); + let join_handle = + spawn_local( + async move { Abortable::new(create_task(), abort_registration).await }, + ); // Send the join handle back to the spawner. If sending fails, // we assume the parent task was canceled, so cancel this task @@ -126,9 +132,8 @@ impl LocalPool { panic!("Failed to send job to worker: {}", e); } - // Wait for the task's join handle. Forward task cancellation in - // case this task gets canceled (via ReceiverCancelGuard). - let join_handle = match ReceiverCancelGuard(receiver).await { + // Wait for the task's join handle + let join_handle = match receiver.await { Ok(handle) => handle, Err(e) => { // We sent the task successfully, but failed to get its @@ -139,12 +144,19 @@ impl LocalPool { } }; - // Wait for the task to complete. Forward task cancellation in case - // this task gets canceled. - let join_result = JoinHandleCancelGuard(join_handle).await; + // Wait for the task to complete + let join_result = join_handle.await; match join_result { - Ok(output) => output, + Ok(Ok(output)) => output, + Ok(Err(_)) => { + // Pinned task was aborted. But that only happens if this + // task is aborted. So this is an impossible branch. + unreachable!( + "Reaching this branch means this task was previously \ + aborted but it continued running anyways" + ) + } Err(e) => { if e.is_panic() { std::panic::resume_unwind(e.into_panic()); @@ -196,63 +208,24 @@ impl LocalPool { /// Automatically decrements a worker's job count when a job finishes (when /// this gets dropped). -struct JobGuard(Arc); +struct JobCountGuard(Arc); -impl Drop for JobGuard { +impl Drop for JobCountGuard { fn drop(&mut self) { // Decrement the job count self.0.fetch_sub(1, Ordering::SeqCst); } } -/// Automatically abort/cancel the task when this guard gets dropped. This will -/// forward a cancellation from one task to another. -/// -/// This implements Future by polling the join handle, so just await it. -struct JoinHandleCancelGuard(JoinHandle); - -impl Future for JoinHandleCancelGuard { - type Output = as Future>::Output; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let join_handle = Pin::new(&mut self.0); - join_handle.poll(cx) - } -} +/// Calls abort on the handle when dropped. +struct AbortGuard(AbortHandle); -impl Drop for JoinHandleCancelGuard { +impl Drop for AbortGuard { fn drop(&mut self) { - // Attempt to abort the task. This does nothing if the task has already - // completed. self.0.abort(); } } -/// If the task is canceled while waiting for the join handle, this guard will -/// check if the join handle was sent (in-transit so it wasn't aborted on the -/// worker side) and abort it if so. -struct ReceiverCancelGuard(oneshot::Receiver>); - -impl Future for ReceiverCancelGuard { - type Output = > as Future>::Output; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let receiver = Pin::new(&mut self.0); - receiver.poll(cx) - } -} - -impl Drop for ReceiverCancelGuard { - fn drop(&mut self) { - // If task is canceled while waiting for the join handle, and the join - // handle was already "sent" by the worker, then it's in a limbo state - // and needs to be manually canceled here. - if let Ok(join_handle) = self.0.try_recv() { - join_handle.abort(); - } - } -} - type PinnedFutureSpawner = Box; struct LocalWorkerHandle {