diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml
index 85df32c2655..e00cdb8cce6 100644
--- a/tokio-util/Cargo.toml
+++ b/tokio-util/Cargo.toml
@@ -29,7 +29,7 @@ codec = []
 time = ["tokio/time","slab"]
 io = []
 io-util = ["io", "tokio/rt", "tokio/io-util"]
-rt = ["tokio/rt"]
+rt = ["tokio/rt", "tokio/sync", "futures-util"]
 
 __docs_rs = ["futures-util"]
 
diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs
index 3786a4002db..28bdb3dd1ef 100644
--- a/tokio-util/src/lib.rs
+++ b/tokio-util/src/lib.rs
@@ -42,6 +42,7 @@ cfg_io! {
 
 cfg_rt! {
     pub mod context;
+    pub mod task;
 }
 
 cfg_time! {
diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs
new file mode 100644
index 00000000000..5aa33df2dc0
--- /dev/null
+++ b/tokio-util/src/task/mod.rs
@@ -0,0 +1,4 @@
+//! Extra utilities for spawning tasks
+
+mod spawn_pinned;
+pub use spawn_pinned::LocalPoolHandle;
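Context for the new module before the implementation diff below: `tokio::spawn` requires `Send` futures, so a future holding `Rc` or other `!Send` state needs a thread it can stay pinned to. Here is a minimal sketch of the problem and of the new API (not part of the diff; assumes a binary depending on `tokio` with the `rt` and `macros` features and on `tokio-util` with the new `rt` feature):

```rust
use std::rc::Rc;
use tokio_util::task::LocalPoolHandle;

#[tokio::main]
async fn main() {
    // This does not compile: the future captures an Rc, making it !Send,
    // and tokio::spawn requires `Send` futures.
    //
    //     let data = Rc::new(1);
    //     tokio::spawn(async move { *data });

    // spawn_pinned takes a Send *closure* instead. The closure runs on a
    // dedicated worker thread, builds the !Send future there, and the
    // future stays pinned to that thread for its whole lifetime.
    let pool = LocalPoolHandle::new(1);
    let out = pool
        .spawn_pinned(|| {
            let data = Rc::new(1); // created on the worker thread
            async move { *data }
        })
        .await
        .unwrap();
    assert_eq!(out, 1);
}
```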
diff --git a/tokio-util/src/task/spawn_pinned.rs b/tokio-util/src/task/spawn_pinned.rs
new file mode 100644
index 00000000000..6f553e9d07a
--- /dev/null
+++ b/tokio-util/src/task/spawn_pinned.rs
@@ -0,0 +1,307 @@
+use futures_util::future::{AbortHandle, Abortable};
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::future::Future;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::runtime::Builder;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::sync::oneshot;
+use tokio::task::{spawn_local, JoinHandle, LocalSet};
+
+/// A handle to a local pool, used for spawning `!Send` tasks.
+#[derive(Clone)]
+pub struct LocalPoolHandle {
+    pool: Arc<LocalPool>,
+}
+
+impl LocalPoolHandle {
+    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
+    /// pool via [`LocalPoolHandle::spawn_pinned`].
+    ///
+    /// # Panics
+    /// Panics if the pool size is less than one.
+    pub fn new(pool_size: usize) -> LocalPoolHandle {
+        assert!(pool_size > 0);
+
+        let workers = (0..pool_size)
+            .map(|_| LocalWorkerHandle::new_worker())
+            .collect();
+
+        let pool = Arc::new(LocalPool { workers });
+
+        LocalPoolHandle { pool }
+    }
+
+    /// Spawn a task onto a worker thread and pin it there so it can't be moved
+    /// off the thread. Note that the future is not [`Send`], but the
+    /// [`FnOnce`] which creates it is.
+    ///
+    /// # Examples
+    /// ```
+    /// use std::rc::Rc;
+    /// use tokio_util::task::LocalPoolHandle;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     // Create the local pool
+    ///     let pool = LocalPoolHandle::new(1);
+    ///
+    ///     // Spawn a !Send future onto the pool and await it
+    ///     let output = pool
+    ///         .spawn_pinned(|| {
+    ///             // Rc is !Send + !Sync
+    ///             let local_data = Rc::new("test");
+    ///
+    ///             // This future holds an Rc, so it is !Send
+    ///             async move { local_data.to_string() }
+    ///         })
+    ///         .await
+    ///         .unwrap();
+    ///
+    ///     assert_eq!(output, "test");
+    /// }
+    /// ```
+    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+    where
+        F: FnOnce() -> Fut,
+        F: Send + 'static,
+        Fut: Future + 'static,
+        Fut::Output: Send + 'static,
+    {
+        self.pool.spawn_pinned(create_task)
+    }
+}
+
+impl Debug for LocalPoolHandle {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_str("LocalPoolHandle")
+    }
+}
+
+struct LocalPool {
+    workers: Vec<LocalWorkerHandle>,
+}
+
+impl LocalPool {
+    /// Spawn a `?Send` future onto a worker
+    fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+    where
+        F: FnOnce() -> Fut,
+        F: Send + 'static,
+        Fut: Future + 'static,
+        Fut::Output: Send + 'static,
+    {
+        let (sender, receiver) = oneshot::channel();
+
+        let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
+        let worker_spawner = worker.spawner.clone();
+
+        // Spawn a future onto the worker's runtime so we can immediately return
+        // a join handle.
+        worker.runtime_handle.spawn(async move {
+            // Move the job guard into the task
+            let _job_guard = job_guard;
+
+            // Propagate aborts via Abortable/AbortHandle
+            let (abort_handle, abort_registration) = AbortHandle::new_pair();
+            let _abort_guard = AbortGuard(abort_handle);
+
+            // Inside the future we can't run spawn_local yet because we're not
+            // in the context of a LocalSet. We need to send create_task to the
+            // LocalSet task for spawning.
+            let spawn_task = Box::new(move || {
+                // Once we're in the LocalSet context we can call spawn_local
+                let join_handle = spawn_local(async move {
+                    Abortable::new(create_task(), abort_registration).await
+                });
+
+                // Send the join handle back to the spawner. If sending fails,
+                // we assume the parent task was canceled, so cancel this task
+                // as well.
+                if let Err(join_handle) = sender.send(join_handle) {
+                    join_handle.abort()
+                }
+            });
+
+            // Send the callback to the LocalSet task
+            if let Err(e) = worker_spawner.send(spawn_task) {
+                // Propagate the error as a panic in the join handle.
+                panic!("Failed to send job to worker: {}", e);
+            }
+
+            // Wait for the task's join handle
+            let join_handle = match receiver.await {
+                Ok(handle) => handle,
+                Err(e) => {
+                    // We sent the task successfully, but failed to get its
+                    // join handle... We assume something happened to the worker
+                    // and the task was not spawned. Propagate the error as a
+                    // panic in the join handle.
+                    panic!("Worker failed to send join handle: {}", e);
+                }
+            };
+
+            // Wait for the task to complete
+            let join_result = join_handle.await;
+
+            match join_result {
+                Ok(Ok(output)) => output,
+                Ok(Err(_)) => {
+                    // The pinned task was aborted. But that only happens if
+                    // this task is aborted, so this branch is impossible.
+                    unreachable!(
+                        "Reaching this branch means this task was previously \
+                         aborted but it continued running anyway"
+                    )
+                }
+                Err(e) => {
+                    if e.is_panic() {
+                        std::panic::resume_unwind(e.into_panic());
+                    } else if e.is_cancelled() {
+                        // No one else should have the join handle, so this is
+                        // unexpected. Forward this error as a panic in the join
+                        // handle.
+                        panic!("spawn_pinned task was canceled: {}", e);
+                    } else {
+                        // Something unknown happened (not a panic or
+                        // cancellation). Forward this error as a panic in the
+                        // join handle.
+                        panic!("spawn_pinned task failed: {}", e);
+                    }
+                }
+            }
+        })
+    }
+
+    /// Find the worker with the least number of tasks, increment its task
+    /// count, and return its handle. Make sure to actually spawn a task on
+    /// the worker so the task count stays consistent with the load.
+    ///
+    /// A job count guard is also returned to ensure the task count gets
+    /// decremented when the job is done.
+    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
+        loop {
+            let (worker, task_count) = self
+                .workers
+                .iter()
+                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
+                .min_by_key(|&(_, count)| count)
+                .expect("There must be at least one worker");
+
+            // Make sure the task count hasn't changed since we chose this
+            // worker. Otherwise, restart the search.
+            if worker
+                .task_count
+                .compare_exchange(
+                    task_count,
+                    task_count + 1,
+                    Ordering::SeqCst,
+                    Ordering::Relaxed,
+                )
+                .is_ok()
+            {
+                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
+            }
+        }
+    }
+}
+
+/// Automatically decrements a worker's job count when a job finishes (when
+/// this gets dropped).
+struct JobCountGuard(Arc<AtomicUsize>);
+
+impl Drop for JobCountGuard {
+    fn drop(&mut self) {
+        // Decrement the job count
+        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
+        debug_assert!(previous_value >= 1);
+    }
+}
+
+/// Calls abort on the handle when dropped.
+struct AbortGuard(AbortHandle);
+
+impl Drop for AbortGuard {
+    fn drop(&mut self) {
+        self.0.abort();
+    }
+}
+
+type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
+
+struct LocalWorkerHandle {
+    runtime_handle: tokio::runtime::Handle,
+    spawner: UnboundedSender<PinnedFutureSpawner>,
+    task_count: Arc<AtomicUsize>,
+}
+
+impl LocalWorkerHandle {
+    /// Create a new worker for executing pinned tasks
+    fn new_worker() -> LocalWorkerHandle {
+        let (sender, receiver) = unbounded_channel();
+        let runtime = Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .expect("Failed to start a pinned worker thread runtime");
+        let runtime_handle = runtime.handle().clone();
+        let task_count = Arc::new(AtomicUsize::new(0));
+        let task_count_clone = Arc::clone(&task_count);
+
+        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
+
+        LocalWorkerHandle {
+            runtime_handle,
+            spawner: sender,
+            task_count,
+        }
+    }
+
+    fn run(
+        runtime: tokio::runtime::Runtime,
+        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
+        task_count: Arc<AtomicUsize>,
+    ) {
+        let local_set = LocalSet::new();
+        local_set.block_on(&runtime, async {
+            while let Some(spawn_task) = task_receiver.recv().await {
+                // Calls spawn_local(future)
+                (spawn_task)();
+            }
+        });
+
+        // If there are any tasks on the runtime associated with a LocalSet task
+        // that has already completed, but whose output has not yet been
+        // reported, let that task complete.
+        //
+        // Since the task_count is decremented when the runtime task exits,
+        // reading that counter lets us know if any such tasks completed during
+        // the call to `block_on`.
+        //
+        // Tasks on the LocalSet can't complete during this loop since they're
+        // stored on the LocalSet and we aren't accessing it.
+        let mut previous_task_count = task_count.load(Ordering::SeqCst);
+        loop {
+            // This call will also run tasks spawned on the runtime.
+            runtime.block_on(tokio::task::yield_now());
+            let new_task_count = task_count.load(Ordering::SeqCst);
+            if new_task_count == previous_task_count {
+                break;
+            } else {
+                previous_task_count = new_task_count;
+            }
+        }
+
+        // It's now no longer possible for a task on the runtime to be
+        // associated with a LocalSet task that has completed. Drop both the
+        // LocalSet and runtime to let tasks on the runtime be cancelled if and
+        // only if they are still on the LocalSet.
+        //
+        // Drop the LocalSet task first so that anyone awaiting the runtime
+        // JoinHandle will see the cancelled error after the LocalSet task
+        // destructor has completed.
+        drop(local_set);
+        drop(runtime);
+    }
+}
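A note on the worker-selection code above: it is lock-free. Each spawn snapshots every worker's task count, picks the minimum, then claims it with a compare-and-swap, retrying the whole search if another spawner raced in between. A standalone sketch of that idiom, with a hypothetical `pick_least_loaded` helper (illustrative only, not from the diff):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn pick_least_loaded(counts: &[AtomicUsize]) -> usize {
    loop {
        // Snapshot all counters and find the least-loaded slot.
        let (idx, seen) = counts
            .iter()
            .enumerate()
            .map(|(i, c)| (i, c.load(Ordering::SeqCst)))
            .min_by_key(|&(_, n)| n)
            .expect("at least one slot");

        // Claim the slot only if its count is unchanged since the snapshot;
        // otherwise another thread raced us, so re-run the search.
        if counts[idx]
            .compare_exchange(seen, seen + 1, Ordering::SeqCst, Ordering::Relaxed)
            .is_ok()
        {
            return idx;
        }
    }
}

fn main() {
    let counts = [AtomicUsize::new(0), AtomicUsize::new(2)];
    assert_eq!(pick_least_loaded(&counts), 0);
    assert_eq!(counts[0].load(Ordering::SeqCst), 1);
}
```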
diff --git a/tokio-util/tests/spawn_pinned.rs b/tokio-util/tests/spawn_pinned.rs
new file mode 100644
index 00000000000..409b8dadab5
--- /dev/null
+++ b/tokio-util/tests/spawn_pinned.rs
@@ -0,0 +1,193 @@
+#![warn(rust_2018_idioms)]
+
+use std::rc::Rc;
+use std::sync::Arc;
+use tokio_util::task;
+
+/// Simple test of running a !Send future via spawn_pinned
+#[tokio::test]
+async fn can_spawn_not_send_future() {
+    let pool = task::LocalPoolHandle::new(1);
+
+    let output = pool
+        .spawn_pinned(|| {
+            // Rc is !Send + !Sync
+            let local_data = Rc::new("test");
+
+            // This future holds an Rc, so it is !Send
+            async move { local_data.to_string() }
+        })
+        .await
+        .unwrap();
+
+    assert_eq!(output, "test");
+}
+
+/// Dropping the join handle still lets the task execute
+#[test]
+fn can_drop_future_and_still_get_output() {
+    let pool = task::LocalPoolHandle::new(1);
+    let (sender, receiver) = std::sync::mpsc::channel();
+
+    let _ = pool.spawn_pinned(move || {
+        // Rc is !Send + !Sync
+        let local_data = Rc::new("test");
+
+        // This future holds an Rc, so it is !Send
+        async move {
+            let _ = sender.send(local_data.to_string());
+        }
+    });
+
+    assert_eq!(receiver.recv(), Ok("test".to_string()));
+}
+
+#[test]
+#[should_panic(expected = "assertion failed: pool_size > 0")]
+fn cannot_create_zero_sized_pool() {
+    let _pool = task::LocalPoolHandle::new(0);
+}
+
+/// We should be able to spawn multiple futures onto the pool at the same time.
+#[tokio::test]
+async fn can_spawn_multiple_futures() {
+    let pool = task::LocalPoolHandle::new(2);
+
+    let join_handle1 = pool.spawn_pinned(|| {
+        let local_data = Rc::new("test1");
+        async move { local_data.to_string() }
+    });
+    let join_handle2 = pool.spawn_pinned(|| {
+        let local_data = Rc::new("test2");
+        async move { local_data.to_string() }
+    });
+
+    assert_eq!(join_handle1.await.unwrap(), "test1");
+    assert_eq!(join_handle2.await.unwrap(), "test2");
+}
+
+/// A panic in the spawned task causes the join handle to return an error,
+/// but you can continue to spawn tasks.
+#[tokio::test]
+async fn task_panic_propagates() {
+    let pool = task::LocalPoolHandle::new(1);
+
+    let join_handle = pool.spawn_pinned(|| async {
+        panic!("Test panic");
+    });
+
+    let result = join_handle.await;
+    assert!(result.is_err());
+    let error = result.unwrap_err();
+    assert!(error.is_panic());
+    let panic_str: &str = *error.into_panic().downcast().unwrap();
+    assert_eq!(panic_str, "Test panic");
+
+    // Trying again with a "safe" task still works
+    let join_handle = pool.spawn_pinned(|| async { "test" });
+    let result = join_handle.await;
+    assert!(result.is_ok());
+    assert_eq!(result.unwrap(), "test");
+}
+
+/// A panic during task creation causes the join handle to return an error,
+/// but you can continue to spawn tasks.
+#[tokio::test]
+async fn callback_panic_does_not_kill_worker() {
+    let pool = task::LocalPoolHandle::new(1);
+
+    let join_handle = pool.spawn_pinned(|| {
+        panic!("Test panic");
+        #[allow(unreachable_code)]
+        async {}
+    });
+
+    let result = join_handle.await;
+    assert!(result.is_err());
+    let error = result.unwrap_err();
+    assert!(error.is_panic());
+    let panic_str: &str = *error.into_panic().downcast().unwrap();
+    assert_eq!(panic_str, "Test panic");
+
+    // Trying again with a "safe" callback works
+    let join_handle = pool.spawn_pinned(|| async { "test" });
+    let result = join_handle.await;
+    assert!(result.is_ok());
+    assert_eq!(result.unwrap(), "test");
+}
+
+/// Canceling the task via the returned join handle cancels the spawned task
+/// (which has a different, internal join handle).
+#[tokio::test]
+async fn task_cancellation_propagates() {
+    let pool = task::LocalPoolHandle::new(1);
+    let notify_dropped = Arc::new(());
+    let weak_notify_dropped = Arc::downgrade(&notify_dropped);
+
+    let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
+    let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
+    let join_handle = pool.spawn_pinned(|| async move {
+        let _drop_sender = drop_sender;
+        // Move the Arc into the task
+        let _notify_dropped = notify_dropped;
+        let _ = start_sender.send(());
+
+        // Keep the task running until it gets aborted
+        futures::future::pending::<()>().await;
+    });
+
+    // Wait for the task to start
+    let _ = start_receiver.await;
+
+    join_handle.abort();
+
+    // Wait for the inner task to abort, dropping the sender.
+    // The top-level join handle aborts more quickly than the inner task (the
+    // abort needs to propagate and get processed on the worker thread), so we
+    // can't just await the top-level join handle.
+    let _ = drop_receiver.await;
+
+    // Check that the Arc has been dropped. This verifies that the inner task
+    // was canceled as well.
+    assert!(weak_notify_dropped.upgrade().is_none());
+}
+
+/// Tasks should be given to the least burdened worker. When spawning two tasks
+/// on a pool with two idle workers, the tasks should be spawned on separate
+/// workers.
+#[tokio::test]
+async fn tasks_are_balanced() {
+    let pool = task::LocalPoolHandle::new(2);
+
+    // Spawn a task so one thread has a task count of 1
+    let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel();
+    let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel();
+    let join_handle1 = pool.spawn_pinned(|| async move {
+        let _ = start_sender1.send(());
+        let _ = end_receiver1.await;
+        std::thread::current().id()
+    });
+
+    // Wait for the first task to start up
+    let _ = start_receiver1.await;
+
+    // This task should be spawned on the other thread
+    let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel();
+    let join_handle2 = pool.spawn_pinned(|| async move {
+        let _ = start_sender2.send(());
+        std::thread::current().id()
+    });
+
+    // Wait for the second task to start up
+    let _ = start_receiver2.await;
+
+    // Allow the first task to end
+    let _ = end_sender1.send(());
+
+    let thread_id1 = join_handle1.await.unwrap();
+    let thread_id2 = join_handle2.await.unwrap();
+
+    // Since the first task was active when the second task spawned, they
+    // should be on separate workers/threads.
+    assert_ne!(thread_id1, thread_id2);
+}
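One behavior the cancellation test exercises that is worth seeing in isolation: aborting the outer join handle propagates to the pinned inner task through the `AbortGuard` in `spawn_pinned`. A small end-to-end sketch (not part of the diff; assumes the same `tokio`/`tokio-util` features as above):

```rust
use tokio_util::task::LocalPoolHandle;

#[tokio::main]
async fn main() {
    let pool = LocalPoolHandle::new(2);

    let handle = pool.spawn_pinned(|| async {
        // Runs until aborted.
        std::future::pending::<()>().await;
    });

    // Aborting the outer handle cancels the wrapper task on the worker's
    // runtime; dropping its AbortGuard then aborts the pinned inner task.
    handle.abort();

    // The task never completed, so the join error reports cancellation.
    let err = handle.await.unwrap_err();
    assert!(err.is_cancelled());
}
```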