diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index 9efe3844a39..09a3d7dd718 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -311,6 +311,9 @@ impl Drop for BasicScheduler

{ }; enter(&mut inner, |scheduler, context| { + // By closing the OwnedTasks, no new tasks can be spawned on it. + context.shared.owned.close(); + // Drain the OwnedTasks collection. while let Some(task) = context.shared.owned.pop_back() { task.shutdown(); } @@ -354,14 +357,18 @@ impl fmt::Debug for BasicScheduler

{ // ===== impl Spawner ===== impl Spawner { - /// Spawns a future onto the thread pool + /// Spawns a future onto the basic scheduler pub(crate) fn spawn(&self, future: F) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (task, handle) = task::joinable(future); - self.shared.schedule(task); + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + + if let Some(notified) = notified { + self.shared.schedule(notified); + } + handle } @@ -392,14 +399,6 @@ impl fmt::Debug for Spawner { // ===== impl Shared ===== impl Schedule for Arc { - fn bind(task: Task) -> Arc { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - cx.shared.owned.push_front(task); - cx.shared.clone() - }) - } - fn release(&self, task: &Task) -> Option> { // SAFETY: Inserted into the list in bind above. unsafe { self.owned.remove(task) } @@ -411,16 +410,13 @@ impl Schedule for Arc { cx.tasks.borrow_mut().queue.push_back(task); } _ => { + // If the queue is None, then the runtime has shut down. We + // don't need to do anything with the notification in that case. let mut guard = self.queue.lock(); if let Some(queue) = guard.as_mut() { queue.push_back(RemoteMsg::Schedule(task)); drop(guard); self.unpark.unpark(); - } else { - // The runtime has shut down. We drop the new task - // immediately. - drop(guard); - task.shutdown(); } } }); diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs index fece3c279d8..670ec3a4b34 100644 --- a/tokio/src/runtime/blocking/mod.rs +++ b/tokio/src/runtime/blocking/mod.rs @@ -8,7 +8,9 @@ pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner}; mod schedule; mod shutdown; -pub(crate) mod task; +mod task; +pub(crate) use schedule::NoopSchedule; +pub(crate) use task::BlockingTask; use crate::runtime::Builder; diff --git a/tokio/src/runtime/blocking/schedule.rs b/tokio/src/runtime/blocking/schedule.rs index 4e044ab2987..54252241d94 100644 --- a/tokio/src/runtime/blocking/schedule.rs +++ b/tokio/src/runtime/blocking/schedule.rs @@ -9,11 +9,6 @@ use crate::runtime::task::{self, Task}; pub(crate) struct NoopSchedule; impl task::Schedule for NoopSchedule { - fn bind(_task: Task) -> NoopSchedule { - // Do nothing w/ the task - NoopSchedule - } - fn release(&self, _task: &Task) -> Option> { None } diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 7dff91448f1..0bfed988a9f 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,4 +1,4 @@ -use crate::runtime::blocking::task::BlockingTask; +use crate::runtime::blocking::{BlockingTask, NoopSchedule}; use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{blocking, context, driver, Spawner}; use crate::util::error::CONTEXT_MISSING_ERROR; @@ -213,7 +213,7 @@ impl Handle { #[cfg(not(all(tokio_unstable, feature = "tracing")))] let _ = name; - let (task, handle) = task::joinable(fut); + let (task, handle) = task::unowned(fut, NoopSchedule); let _ = self.blocking_spawner.spawn(task, &self); handle } diff --git a/tokio/src/runtime/queue.rs b/tokio/src/runtime/queue.rs index 6e91dfa2363..c45cb6a5a3d 100644 --- a/tokio/src/runtime/queue.rs +++ b/tokio/src/runtime/queue.rs @@ -106,13 +106,8 @@ impl Local { break tail; } else if steal != real { // Concurrently stealing, this will free up capacity, so only - // push the new task onto the inject queue - // - // If the task fails to be pushed on the injection queue, there - // is nothing to be done at this point as the 
task cannot be a - // newly spawned task. Shutting down this task is handled by the - // worker shutdown process. - let _ = inject.push(task); + // push the task onto the inject queue + inject.push(task); return; } else { // Push the current task and half of the queue into the diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index e4624c7b709..06848e73460 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -93,7 +93,7 @@ pub(super) enum Stage { impl Cell { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, state: State) -> Box> { + pub(super) fn new(future: T, scheduler: S, state: State) -> Box> { #[cfg(all(tokio_unstable, feature = "tracing"))] let id = future.id(); Box::new(Cell { @@ -107,7 +107,7 @@ impl Cell { }, core: Core { scheduler: Scheduler { - scheduler: UnsafeCell::new(None), + scheduler: UnsafeCell::new(Some(scheduler)), }, stage: CoreStage { stage: UnsafeCell::new(Stage::Running(future)), @@ -125,34 +125,6 @@ impl Scheduler { self.scheduler.with_mut(f) } - /// Bind a scheduler to the task. - /// - /// This only happens on the first poll and must be preceded by a call to - /// `is_bound` to determine if binding is appropriate or not. - /// - /// # Safety - /// - /// Binding must not be done concurrently since it will mutate the task - /// core through a shared reference. - pub(super) fn bind_scheduler(&self, task: Task) { - // This function may be called concurrently, but the __first__ time it - // is called, the caller has unique access to this field. All subsequent - // concurrent calls will be via the `Waker`, which will "happens after" - // the first poll. - // - // In other words, it is always safe to read the field and it is safe to - // write to the field when it is `None`. - debug_assert!(!self.is_bound()); - - // Bind the task to the scheduler - let scheduler = S::bind(task); - - // Safety: As `scheduler` is not set, this is the first poll - self.scheduler.with_mut(|ptr| unsafe { - *ptr = Some(scheduler); - }); - } - /// Returns true if the task is bound to a scheduler. pub(super) fn is_bound(&self) -> bool { // Safety: never called concurrently w/ a mutation. diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 9f0b1071130..7d19f3d6488 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -254,16 +254,13 @@ where } fn transition_to_running(&self) -> TransitionToRunning { - // If this is the first time the task is polled, the task will be bound - // to the scheduler, in which case the task ref count must be - // incremented. - let is_not_bound = !self.scheduler.is_bound(); + debug_assert!(self.scheduler.is_bound()); // Transition the task to the running state. // // A failure to transition here indicates the task has been cancelled // while in the run queue pending execution. - let snapshot = match self.header.state.transition_to_running(is_not_bound) { + let snapshot = match self.header.state.transition_to_running() { Ok(snapshot) => snapshot, Err(_) => { // The task was shutdown while in the run queue. At this point, @@ -273,20 +270,6 @@ where } }; - if is_not_bound { - // Ensure the task is bound to a scheduler instance. Since this is - // the first time polling the task, a scheduler instance is pulled - // from the local context and assigned to the task. - // - // The scheduler maintains ownership of the task and responds to - // `wake` calls. 
- // - // The task reference count has been incremented. - // - // Safety: Since we have unique access to the task so that we can - // safely call `bind_scheduler`. - self.scheduler.bind_scheduler(self.to_task()); - } TransitionToRunning::Ok(snapshot) } } diff --git a/tokio/src/runtime/task/inject.rs b/tokio/src/runtime/task/inject.rs index 8ca3187a722..640da648f74 100644 --- a/tokio/src/runtime/task/inject.rs +++ b/tokio/src/runtime/task/inject.rs @@ -75,15 +75,13 @@ impl Inject { /// Pushes a value into the queue. /// - /// Returns `Err(task)` if pushing fails due to the queue being shutdown. - /// The caller is expected to call `shutdown()` on the task **if and only - /// if** it is a newly spawned task. - pub(crate) fn push(&self, task: task::Notified) -> Result<(), task::Notified> { + /// This does nothing if the queue is closed. + pub(crate) fn push(&self, task: task::Notified) { // Acquire queue lock let mut p = self.pointers.lock(); if p.is_closed { - return Err(task); + return; } // safety: only mutated with the lock held @@ -102,7 +100,6 @@ impl Inject { p.tail = Some(task); self.len.store(len + 1, Release); - Ok(()) } /// Pushes several values into the queue. diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 45e22a72af2..63812f9c76f 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -1,33 +1,141 @@ +//! This module has containers for storing the tasks spawned on a scheduler. The +//! `OwnedTasks` container is thread-safe but can only store tasks that +//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can +//! store non-Send tasks. +//! +//! The collections can be closed to prevent adding new tasks during shutdown of +//! the scheduler with the collection. + +use crate::future::Future; use crate::loom::sync::Mutex; -use crate::runtime::task::Task; +use crate::runtime::task::{JoinHandle, Notified, Schedule, Task}; use crate::util::linked_list::{Link, LinkedList}; +use std::marker::PhantomData; + pub(crate) struct OwnedTasks { - list: Mutex, as Link>::Target>>, + inner: Mutex>, +} +struct OwnedTasksInner { + list: LinkedList, as Link>::Target>, + closed: bool, +} + +pub(crate) struct LocalOwnedTasks { + list: LinkedList, as Link>::Target>, + closed: bool, + _not_send: PhantomData<*const ()>, } impl OwnedTasks { pub(crate) fn new() -> Self { Self { - list: Mutex::new(LinkedList::new()), + inner: Mutex::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), } } - pub(crate) fn push_front(&self, task: Task) { - self.list.lock().push_front(task); + /// Bind the provided task to this OwnedTasks instance. This fails if the + /// OwnedTasks has been closed. + pub(crate) fn bind( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle, Option>) + where + S: Schedule, + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + let mut lock = self.inner.lock(); + if lock.closed { + drop(lock); + drop(task); + notified.shutdown(); + (join, None) + } else { + lock.list.push_front(task); + (join, Some(notified)) + } } pub(crate) fn pop_back(&self) -> Option> { - self.list.lock().pop_back() + self.inner.lock().list.pop_back() } /// The caller must ensure that if the provided task is stored in a /// linked list, then it is in this linked list. 
pub(crate) unsafe fn remove(&self, task: &Task) -> Option> { - self.list.lock().remove(task.header().into()) + self.inner.lock().list.remove(task.header().into()) + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.lock().list.is_empty() + } + + #[cfg(feature = "rt-multi-thread")] + pub(crate) fn is_closed(&self) -> bool { + self.inner.lock().closed + } + + /// Close the OwnedTasks. This prevents adding new tasks to the collection. + pub(crate) fn close(&self) { + self.inner.lock().closed = true; + } +} + +impl LocalOwnedTasks { + pub(crate) fn new() -> Self { + Self { + list: LinkedList::new(), + closed: false, + _not_send: PhantomData, + } + } + + pub(crate) fn bind( + &mut self, + task: T, + scheduler: S, + ) -> (JoinHandle, Option>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + if self.closed { + drop(task); + notified.shutdown(); + (join, None) + } else { + self.list.push_front(task); + (join, Some(notified)) + } + } + + pub(crate) fn pop_back(&mut self) -> Option> { + self.list.pop_back() + } + + /// The caller must ensure that if the provided task is stored in a + /// linked list, then it is in this linked list. + pub(crate) unsafe fn remove(&mut self, task: &Task) -> Option> { + self.list.remove(task.header().into()) } pub(crate) fn is_empty(&self) -> bool { - self.list.lock().is_empty() + self.list.is_empty() + } + + /// Close the LocalOwnedTasks. This prevents adding new tasks to the + /// collection. + pub(crate) fn close(&mut self) { + self.closed = true; } } diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 5e2477906c6..1ca5ba579b5 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -19,7 +19,7 @@ mod join; pub use self::join::JoinHandle; mod list; -pub(crate) use self::list::OwnedTasks; +pub(crate) use self::list::{LocalOwnedTasks, OwnedTasks}; mod raw; use self::raw::RawTask; @@ -57,13 +57,6 @@ unsafe impl Sync for Notified {} pub(crate) type Result = std::result::Result; pub(crate) trait Schedule: Sync + Sized + 'static { - /// Bind a task to the executor. - /// - /// Guaranteed to be called from the thread that called `poll` on the task. - /// The returned `Schedule` instance is associated with the task and is used - /// as `&self` in the other methods on this trait. - fn bind(task: Task) -> Self; - /// The task has completed work and is ready to be released. The scheduler /// should release it immediately and return it. The task module will batch /// the ref-dec with setting other options. @@ -82,42 +75,46 @@ pub(crate) trait Schedule: Sync + Sized + 'static { } cfg_rt! { - /// Create a new task with an associated join handle - pub(crate) fn joinable(task: T) -> (Notified, JoinHandle) + /// This is the constructor for a new task. Three references to the task are + /// created. The first task reference is usually put into an OwnedTasks + /// immediately. The Notified is sent to the scheduler as an ordinary + /// notification. + fn new_task( + task: T, + scheduler: S + ) -> (Task, Notified, JoinHandle) where - T: Future + Send + 'static, S: Schedule, + T: Future + 'static, + T::Output: 'static, { - let raw = RawTask::new::<_, S>(task); - + let raw = RawTask::new::(task, scheduler); let task = Task { raw, _p: PhantomData, }; - + let notified = Notified(Task { + raw, + _p: PhantomData, + }); let join = JoinHandle::new(raw); - (Notified(task), join) + (task, notified, join) } -} -cfg_rt! 
{ - /// Create a new `!Send` task with an associated join handle - pub(crate) unsafe fn joinable_local(task: T) -> (Notified, JoinHandle) + /// Create a new task with an associated join handle. This method is used + /// only when the task is not going to be stored in an `OwnedTasks` list. + /// + /// Currently only blocking tasks and tests use this method. + pub(crate) fn unowned(task: T, scheduler: S) -> (Notified, JoinHandle) where - T: Future + 'static, S: Schedule, + T: Send + Future + 'static, + T::Output: Send + 'static, { - let raw = RawTask::new::<_, S>(task); - - let task = Task { - raw, - _p: PhantomData, - }; - - let join = JoinHandle::new(raw); - - (Notified(task), join) + let (task, notified, join) = new_task(task, scheduler); + drop(task); + (notified, join) } } diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 56d65d5a649..8c2c3f73291 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -42,12 +42,12 @@ pub(super) fn vtable() -> &'static Vtable { } impl RawTask { - pub(super) fn new(task: T) -> RawTask + pub(super) fn new(task: T, scheduler: S) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, State::new())); + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; RawTask { ptr } diff --git a/tokio/src/runtime/task/state.rs b/tokio/src/runtime/task/state.rs index 6037721623f..2580e9c83d2 100644 --- a/tokio/src/runtime/task/state.rs +++ b/tokio/src/runtime/task/state.rs @@ -54,22 +54,23 @@ const REF_ONE: usize = 1 << REF_COUNT_SHIFT; /// State a task is initialized with /// -/// A task is initialized with two references: one for the scheduler and one for -/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is -/// set. A new task is immediately pushed into the run queue for execution and -/// starts with the `NOTIFIED` flag set. -const INITIAL_STATE: usize = (REF_ONE * 2) | JOIN_INTEREST | NOTIFIED; +/// A task is initialized with three references: +/// +/// * A reference that will be stored in an OwnedTasks or LocalOwnedTasks. +/// * A reference that will be sent to the scheduler as an ordinary notification. +/// * A reference for the JoinHandle. +/// +/// As the task starts with a `JoinHandle`, `JOIN_INTEREST` is set. +/// As the task starts with a `Notified`, `NOTIFIED` is set. +const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. impl State { /// Return a task's initial state pub(super) fn new() -> State { - // A task is initialized with three references: one for the scheduler, - // one for the `JoinHandle`, one for the task handle made available in - // release. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is - // set. A new task is immediately pushed into the run queue for - // execution and starts with the `NOTIFIED` flag set. + // The raw task returned by this method has a ref-count of three. See + // the comment on INITIAL_STATE for more. State { val: AtomicUsize::new(INITIAL_STATE), } @@ -82,10 +83,8 @@ impl State { /// Attempt to transition the lifecycle to `Running`. /// - /// If `ref_inc` is set, the reference count is also incremented. - /// /// The `NOTIFIED` bit is always unset. 
- pub(super) fn transition_to_running(&self, ref_inc: bool) -> UpdateResult { + pub(super) fn transition_to_running(&self) -> UpdateResult { self.fetch_update(|curr| { assert!(curr.is_notified()); @@ -95,10 +94,6 @@ impl State { return None; } - if ref_inc { - next.ref_inc(); - } - next.set_running(); next.unset_notified(); Some(next) diff --git a/tokio/src/runtime/tests/loom_queue.rs b/tokio/src/runtime/tests/loom_queue.rs index 977c9159b26..a1ed1717b90 100644 --- a/tokio/src/runtime/tests/loom_queue.rs +++ b/tokio/src/runtime/tests/loom_queue.rs @@ -1,5 +1,6 @@ +use crate::runtime::blocking::NoopSchedule; use crate::runtime::queue; -use crate::runtime::task::{self, Inject, Schedule, Task}; +use crate::runtime::task::Inject; use loom::thread; @@ -30,7 +31,7 @@ fn basic() { for _ in 0..2 { for _ in 0..2 { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -39,7 +40,7 @@ fn basic() { } // Push another task - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); while local.pop().is_some() { @@ -81,7 +82,7 @@ fn steal_overflow() { let mut n = 0; // push a task, pop a task - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); if local.pop().is_some() { @@ -89,7 +90,7 @@ fn steal_overflow() { } for _ in 0..6 { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -111,7 +112,7 @@ fn steal_overflow() { fn multi_stealer() { const NUM_TASKS: usize = 5; - fn steal_tasks(steal: queue::Steal) -> usize { + fn steal_tasks(steal: queue::Steal) -> usize { let (_, mut local) = queue::local(); if steal.steal_into(&mut local).is_none() { @@ -133,7 +134,7 @@ fn multi_stealer() { // Push work for _ in 0..NUM_TASKS { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -170,10 +171,10 @@ fn chained_steal() { // Load up some tasks for _ in 0..4 { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); l1.push_back(task, &inject); - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); l2.push_back(task, &inject); } @@ -197,20 +198,3 @@ fn chained_steal() { while inject.pop().is_some() {} }); } - -struct Runtime; - -impl Schedule for Runtime { - fn bind(task: Task) -> Runtime { - std::mem::forget(task); - Runtime - } - - fn release(&self, _task: &Task) -> Option> { - None - } - - fn schedule(&self, _task: task::Notified) { - unreachable!(); - } -} diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 596e47dfd00..f3c6a9bd700 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -1,21 +1,30 @@ -#[cfg(not(all(tokio_unstable, feature = "tracing")))] -use crate::runtime::task::joinable; +use self::unowned_wrapper::unowned; -#[cfg(all(tokio_unstable, feature = "tracing"))] -use self::joinable_wrapper::joinable; +mod unowned_wrapper { + use crate::runtime::blocking::NoopSchedule; + use crate::runtime::task::{JoinHandle, Notified}; -#[cfg(all(tokio_unstable, feature = "tracing"))] -mod joinable_wrapper { - use crate::runtime::task::{JoinHandle, Notified, Schedule}; - use tracing::Instrument; - - pub(crate) fn joinable(task: T) -> (Notified, JoinHandle) + 
#[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn unowned(task: T) -> (Notified, JoinHandle) where T: std::future::Future + Send + 'static, - S: Schedule, + T::Output: Send + 'static, { + use tracing::Instrument; let span = tracing::trace_span!("test_span"); - crate::runtime::task::joinable(task.instrument(span)) + let task = task.instrument(span); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task, handle) + } + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + pub(crate) fn unowned(task: T) -> (Notified, JoinHandle) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task, handle) } } diff --git a/tokio/src/runtime/tests/queue.rs b/tokio/src/runtime/tests/queue.rs index e08dc6d99e6..428b002071a 100644 --- a/tokio/src/runtime/tests/queue.rs +++ b/tokio/src/runtime/tests/queue.rs @@ -10,7 +10,7 @@ fn fits_256() { let inject = Inject::new(); for _ in 0..256 { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -25,7 +25,7 @@ fn overflow() { let inject = Inject::new(); for _ in 0..257 { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -49,7 +49,7 @@ fn steal_batch() { let inject = Inject::new(); for _ in 0..4 { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local1.push_back(task, &inject); } @@ -103,7 +103,7 @@ fn stress1() { for _ in 0..NUM_LOCAL { for _ in 0..NUM_PUSH { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -158,7 +158,7 @@ fn stress2() { let mut num_pop = 0; for i in 0..NUM_TASKS { - let (task, _) = super::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); if i % 128 == 0 && local.pop().is_some() { @@ -187,11 +187,6 @@ fn stress2() { struct Runtime; impl Schedule for Runtime { - fn bind(task: Task) -> Runtime { - std::mem::forget(task); - Runtime - } - fn release(&self, _task: &Task) -> Option> { None } diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index 1f3e89d7661..ae5c04c840c 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -1,43 +1,185 @@ -use crate::runtime::task::{self, OwnedTasks, Schedule, Task}; +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; use crate::util::TryLock; use std::collections::VecDeque; +use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +struct AssertDropHandle { + is_dropped: Arc, +} +impl AssertDropHandle { + #[track_caller] + fn assert_dropped(&self) { + assert!(self.is_dropped.load(Ordering::SeqCst)); + } + + #[track_caller] + fn assert_not_dropped(&self) { + assert!(!self.is_dropped.load(Ordering::SeqCst)); + } +} + +struct AssertDrop { + is_dropped: Arc, +} +impl AssertDrop { + fn new() -> (Self, AssertDropHandle) { + let shared = Arc::new(AtomicBool::new(false)); + ( + AssertDrop { + is_dropped: shared.clone(), + }, + AssertDropHandle { + is_dropped: shared.clone(), + }, + ) + } +} +impl Drop for AssertDrop { + fn drop(&mut self) { + self.is_dropped.store(true, Ordering::SeqCst); + } +} + +// A Notified does not shut down 
on drop, but it is dropped once the ref-count +// hits zero. +#[test] +fn create_drop1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(notified); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + #[test] -fn create_drop() { - let _ = super::joinable::<_, Runtime>(async { unreachable!() }); +fn create_drop2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_dropped(); +} + +// Shutting down through Notified works +#[test] +fn create_shutdown1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); +} + +#[test] +fn create_shutdown2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); + drop(join); } #[test] fn schedule() { with(|rt| { - let (task, _) = super::joinable(async { + rt.spawn(async { crate::task::yield_now().await; }); - rt.schedule(task); - assert_eq!(2, rt.tick()); + rt.shutdown(); }) } #[test] fn shutdown() { with(|rt| { - let (task, _) = super::joinable(async { + rt.spawn(async { loop { crate::task::yield_now().await; } }); - rt.schedule(task); rt.tick_max(1); rt.shutdown(); }) } +#[test] +fn shutdown_immediately() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.shutdown(); + }) +} + +#[test] +fn spawn_during_shutdown() { + static DID_SPAWN: AtomicBool = AtomicBool::new(false); + + struct SpawnOnDrop(Runtime); + impl Drop for SpawnOnDrop { + fn drop(&mut self) { + DID_SPAWN.store(true, Ordering::SeqCst); + self.0.spawn(async {}); + } + } + + with(|rt| { + let rt2 = rt.clone(); + rt.spawn(async move { + let _spawn_on_drop = SpawnOnDrop(rt2); + + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + rt.shutdown(); + }); + + assert!(DID_SPAWN.load(Ordering::SeqCst)); +} + fn with(f: impl FnOnce(Runtime)) { struct Reset; @@ -75,6 +217,20 @@ struct Core { static CURRENT: TryLock> = TryLock::new(None); impl Runtime { + fn spawn(&self, future: T) -> JoinHandle + where + T: 'static + Send + Future, + T::Output: 'static + Send, + { + let (handle, notified) = self.0.owned.bind(future, self.clone()); + + if let Some(notified) = notified { + self.schedule(notified); + } + + handle + } + fn tick(&self) -> usize { self.tick_max(usize::MAX) } @@ -102,6 +258,7 @@ impl Runtime { fn shutdown(&self) { let mut core = self.0.core.try_lock().unwrap(); + self.0.owned.close(); while let Some(task) = self.0.owned.pop_back() { task.shutdown(); } @@ -117,12 +274,6 @@ impl Runtime { } impl Schedule for Runtime { - fn bind(task: Task) -> Runtime { - let rt = CURRENT.try_lock().unwrap().as_ref().unwrap().clone(); - rt.0.owned.push_front(task); - rt - } - fn release(&self, task: &Task) -> Option> { // safety: copying worker.rs unsafe { self.0.owned.remove(task) } diff --git a/tokio/src/runtime/thread_pool/mod.rs b/tokio/src/runtime/thread_pool/mod.rs index 96312d34618..3808aa26465 100644 --- a/tokio/src/runtime/thread_pool/mod.rs +++ b/tokio/src/runtime/thread_pool/mod.rs @@ -12,7 +12,7 @@ pub(crate) 
use worker::Launch; pub(crate) use worker::block_in_place; use crate::loom::sync::Arc; -use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::task::JoinHandle; use crate::runtime::Parker; use std::fmt; @@ -30,7 +30,7 @@ pub(crate) struct ThreadPool { /// /// The `Spawner` handle is *only* used for spawning new futures. It does not /// impact the lifecycle of the thread pool in any way. The thread pool may -/// shutdown while there are outstanding `Spawner` instances. +/// shut down while there are outstanding `Spawner` instances. /// /// `Spawner` instances are obtained by calling [`ThreadPool::spawner`]. /// @@ -93,15 +93,7 @@ impl Spawner { F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (task, handle) = task::joinable(future); - - if let Err(task) = self.shared.schedule(task, false) { - // The newly spawned task could not be scheduled because the runtime - // is shutting down. The task must be explicitly shutdown at this point. - task.shutdown(); - } - - handle + worker::Shared::bind_new_task(&self.shared, future) } pub(crate) fn shutdown(&mut self) { diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs index e91e2ea4b34..c0535238ced 100644 --- a/tokio/src/runtime/thread_pool/worker.rs +++ b/tokio/src/runtime/thread_pool/worker.rs @@ -3,15 +3,68 @@ //! run queue and other state. When `block_in_place` is called, the worker's //! "core" is handed off to a new thread allowing the scheduler to continue to //! make progress while the originating thread blocks. +//! +//! # Shutdown +//! +//! Shutting down the runtime involves the following steps: +//! +//! 1. The Shared::close method is called. This closes the inject queue and +//! OwnedTasks instance and wakes up all worker threads. +//! +//! 2. Each worker thread observes the close signal next time it runs +//! Core::maintenance by checking whether the inject queue is closed. +//! The Core::is_shutdown flag is set to true. +//! +//! 3. The worker thread calls `pre_shutdown` in parallel. Here, the worker +//! will keep removing tasks from OwnedTasks until it is empty. No new +//! tasks can be pushed to the OwnedTasks during or after this step as it +//! was closed in step 1. +//! +//! 4. The workers call Shared::shutdown to enter the single-threaded phase of +//! shutdown. These calls will push their core to Shared::shutdown_cores, +//! and the last thread to push its core will finish the shutdown procedure. +//! +//! 5. The local run queue of each core is emptied, then the inject queue is +//! emptied. +//! +//! At this point, shutdown has completed. It is not possible for any of the +//! collections to contain any tasks at this point, as each collection was +//! closed first, then emptied afterwards. +//! +//! ## Spawns during shutdown +//! +//! When spawning tasks during shutdown, there are two cases: +//! +//! * The spawner observes the OwnedTasks being open, and the inject queue is +//! closed. +//! * The spawner observes the OwnedTasks being closed and doesn't check the +//! inject queue. +//! +//! The first case can only happen if the OwnedTasks::bind call happens before +//! or during step 1 of shutdown. In this case, the runtime will clean up the +//! task in step 3 of shutdown. +//! +//! In the latter case, the task was not spawned and the task is immediately +//! cancelled by the spawner. +//! +//! The correctness of shutdown requires both the inject queue and OwnedTasks +//! collection to have a closed bit.
With a close bit on only the inject queue, +//! spawning could run into a situation where a task is successfully bound long +//! after the runtime has shut down. With a close bit on only the OwnedTasks, +//! the first spawning situation could result in the notification being pushed +//! to the inject queue after step 5 of shutdown, which would leave a task in +//! the inject queue indefinitely. This would be a ref-count cycle and a memory +//! leak. use crate::coop; +use crate::future::Future; use crate::loom::rand::seed; use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; use crate::runtime; use crate::runtime::enter::EnterContext; use crate::runtime::park::{Parker, Unparker}; -use crate::runtime::task::{Inject, OwnedTasks}; +use crate::runtime::task::{Inject, JoinHandle, OwnedTasks}; use crate::runtime::thread_pool::{AtomicCell, Idle}; use crate::runtime::{queue, task}; use crate::util::FastRand; @@ -44,7 +97,7 @@ struct Core { lifo_slot: Option, /// The worker-local run queue. - run_queue: queue::Local>, + run_queue: queue::Local>, /// True if the worker is currently searching for more work. Searching /// involves attempting to steal from other workers. @@ -70,13 +123,13 @@ pub(super) struct Shared { remotes: Box<[Remote]>, /// Submit work to the scheduler while **not** currently on a worker thread. - inject: Inject>, + inject: Inject>, /// Coordinates idle workers idle: Idle, /// Collection of all active tasks spawned onto this executor. - owned: OwnedTasks>, + owned: OwnedTasks>, /// Cores that have observed the shutdown signal /// @@ -89,7 +142,7 @@ pub(super) struct Shared { /// Used to communicate with a worker from other threads. struct Remote { /// Steal tasks from this worker. - steal: queue::Steal>, + steal: queue::Steal>, /// Unparks the associated worker thread unpark: Unparker, @@ -113,10 +166,10 @@ pub(crate) struct Launch(Vec>); type RunResult = Result, ()>; /// A task handle -type Task = task::Task>; +type Task = task::Task>; /// A notified task handle -type Notified = task::Notified>; +type Notified = task::Notified>; // Tracks thread-local state scoped_thread_local!(static CURRENT: Context); @@ -543,13 +596,16 @@ impl Core { /// Signals all tasks to shut down, and waits for them to complete. Must run /// before we enter the single-threaded phase of shutdown processing. fn pre_shutdown(&mut self, worker: &Worker) { + // The OwnedTasks was closed in Shared::close. + debug_assert!(worker.shared.owned.is_closed()); + // Signal to all tasks to shut down. while let Some(header) = worker.shared.owned.pop_back() { header.shutdown(); } } - // Shutdown the core + /// Shutdown the core fn shutdown(&mut self) { // Take the core let mut park = self.park.take().expect("park missing"); @@ -563,46 +619,42 @@ impl Core { impl Worker { /// Returns a reference to the scheduler's injection queue - fn inject(&self) -> &Inject> { &self.shared.inject } } -impl task::Schedule for Arc { - fn bind(task: Task) -> Arc { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - - // Track the task - cx.worker.shared.owned.push_front(task); - - // Return a clone of the worker - cx.worker.clone() - }) - } - +impl task::Schedule for Arc { fn release(&self, task: &Task) -> Option { // SAFETY: Inserted into owned in bind.
- unsafe { self.shared.owned.remove(task) } + unsafe { self.owned.remove(task) } } fn schedule(&self, task: Notified) { - // Because this is not a newly spawned task, if scheduling fails due to - // the runtime shutting down, there is no special work that must happen - // here. - let _ = self.shared.schedule(task, false); + (**self).schedule(task, false); } fn yield_now(&self, task: Notified) { - // Because this is not a newly spawned task, if scheduling fails due to - // the runtime shutting down, there is no special work that must happen - // here. - let _ = self.shared.schedule(task, true); + (**self).schedule(task, true); } } impl Shared { - pub(super) fn schedule(&self, task: Notified, is_yield: bool) -> Result<(), Notified> { + pub(super) fn bind_new_task(me: &Arc, future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (handle, notified) = me.owned.bind(future, me.clone()); + + if let Some(notified) = notified { + me.schedule(notified, false); + } + + handle + } + + pub(super) fn schedule(&self, task: Notified, is_yield: bool) { CURRENT.with(|maybe_cx| { if let Some(cx) = maybe_cx { // Make sure the task is part of the **current** scheduler. @@ -610,15 +662,14 @@ impl Shared { // And the current thread still holds a core if let Some(core) = cx.core.borrow_mut().as_mut() { self.schedule_local(core, task, is_yield); - return Ok(()); + return; } } } - // Otherwise, use the inject queue - self.inject.push(task)?; + // Otherwise, use the inject queue. + self.inject.push(task); self.notify_parked(); - Ok(()) }) } @@ -654,6 +705,7 @@ impl Shared { pub(super) fn close(&self) { if self.inject.close() { + self.owned.close(); self.notify_all(); } } diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 7404cc2c19b..c74c8b438ed 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,6 +1,6 @@ //! Runs `!Send` futures on the current thread. use crate::loom::sync::{Arc, Mutex}; -use crate::runtime::task::{self, JoinHandle, OwnedTasks, Task}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; use std::cell::{Cell, RefCell}; @@ -232,7 +232,7 @@ struct Context { struct Tasks { /// Collection of all active tasks spawned onto this executor. - owned: OwnedTasks>, + owned: LocalOwnedTasks>, /// Local run queue sender and receiver. queue: VecDeque>>, @@ -308,10 +308,12 @@ cfg_rt! { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); - // Safety: Tasks are only polled and dropped from the thread that - // spawns them. 
- let (task, handle) = unsafe { task::joinable_local(future) }; - cx.tasks.borrow_mut().queue.push_back(task); + let (handle, notified) = cx.tasks.borrow_mut().owned.bind(future, cx.shared.clone()); + + if let Some(notified) = notified { + cx.shared.schedule(notified); + } + handle }) } @@ -333,7 +335,7 @@ impl LocalSet { tick: Cell::new(0), context: Context { tasks: RefCell::new(Tasks { - owned: OwnedTasks::new(), + owned: LocalOwnedTasks::new(), queue: VecDeque::with_capacity(INITIAL_CAPACITY), }), shared: Arc::new(Shared { @@ -388,8 +390,18 @@ impl LocalSet { F::Output: 'static, { let future = crate::util::trace::task(future, "local", None); - let (task, handle) = unsafe { task::joinable_local(future) }; - self.context.tasks.borrow_mut().queue.push_back(task); + + let (handle, notified) = self + .context + .tasks + .borrow_mut() + .owned + .bind(future, self.context.shared.clone()); + + if let Some(notified) = notified { + self.context.shared.schedule(notified); + } + self.context.shared.waker.wake(); handle } @@ -593,6 +605,12 @@ impl Default for LocalSet { impl Drop for LocalSet { fn drop(&mut self) { self.with(|| { + // Close the LocalOwnedTasks. This ensures that any calls to + // spawn_local in the destructor of a future on this LocalSet will + // immediately cancel the task, and prevents the task from being + // added to `owned`. + self.context.tasks.borrow_mut().owned.close(); + // Loop required here to ensure borrow is dropped between iterations #[allow(clippy::while_let_loop)] loop { @@ -671,14 +689,6 @@ impl Shared { } impl task::Schedule for Arc { - fn bind(task: Task) -> Arc { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - cx.tasks.borrow_mut().owned.push_front(task); - cx.shared.clone() - }) - } - fn release(&self, task: &Task) -> Option> { CURRENT.with(|maybe_cx| { let cx = maybe_cx.expect("scheduler context missing"); @@ -686,7 +696,7 @@ impl task::Schedule for Arc { assert!(cx.shared.ptr_eq(self)); // safety: task must be contained by list. It is inserted into the - // list in `bind`. + // list when spawning. unsafe { cx.tasks.borrow_mut().owned.remove(&task) } }) } diff --git a/tokio/tests/support/mock_file.rs b/tokio/tests/support/mock_file.rs index 1ce326b62aa..60f6bedbb73 100644 --- a/tokio/tests/support/mock_file.rs +++ b/tokio/tests/support/mock_file.rs @@ -211,7 +211,7 @@ impl Read for &'_ File { assert!(dst.len() >= data.len()); assert!(dst.len() <= 16 * 1024, "actual = {}", dst.len()); // max buffer - &mut dst[..data.len()].copy_from_slice(&data); + dst[..data.len()].copy_from_slice(&data); Ok(data.len()) } Some(Read(Err(e))) => Err(e),
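
A note on the spawn path this patch converges on: every scheduler (the basic scheduler, the thread-pool worker, and the LocalSet) now calls `owned.bind(future, scheduler)`, receives a `JoinHandle` plus an `Option<Notified>`, and only pushes to a run queue when the option is `Some`. The sketch below is a minimal, self-contained model of that shape and of the close-then-drain order used during shutdown; `OwnedTasks`, `Task`, `Notified`, and `JoinHandle` here are simplified stand-ins for illustration, not the real tokio types.

```rust
use std::sync::Mutex;

// Simplified stand-ins for the real task handles. In tokio these wrap a raw
// task pointer; an integer id is enough to show the ownership shape here.
struct Task(u64);
struct Notified(u64);
#[allow(dead_code)]
struct JoinHandle(u64);

impl Notified {
    // Models Notified::shutdown(): cancel a task that was never scheduled.
    fn shutdown(self) {
        println!("task {} cancelled before it was ever scheduled", self.0);
    }
}

impl Task {
    // Models Task::shutdown(): cancel a task still held by the owned list.
    fn shutdown(self) {
        println!("task {} shut down while draining the owned list", self.0);
    }
}

// Minimal model of the OwnedTasks collection: a closable list of tasks.
struct OwnedTasks {
    inner: Mutex<Inner>,
}

struct Inner {
    list: Vec<Task>,
    closed: bool,
}

impl OwnedTasks {
    fn new() -> Self {
        OwnedTasks {
            inner: Mutex::new(Inner { list: Vec::new(), closed: false }),
        }
    }

    // Mirrors the shape of OwnedTasks::bind in the patch: when the collection
    // is closed, the caller gets no Notified back and therefore never touches
    // the run queues; the task is cancelled on the spot instead.
    fn bind(&self, id: u64) -> (JoinHandle, Option<Notified>) {
        let mut lock = self.inner.lock().unwrap();
        if lock.closed {
            drop(lock);
            Notified(id).shutdown();
            (JoinHandle(id), None)
        } else {
            lock.list.push(Task(id));
            (JoinHandle(id), Some(Notified(id)))
        }
    }

    fn close(&self) {
        self.inner.lock().unwrap().closed = true;
    }

    fn pop_back(&self) -> Option<Task> {
        self.inner.lock().unwrap().list.pop()
    }
}

fn main() {
    let owned = OwnedTasks::new();

    // Normal spawn: the task is tracked and a notification is handed back,
    // which the real schedulers push onto a run queue.
    let (_join, notified) = owned.bind(1);
    assert!(notified.is_some());

    // Shutdown order used by the patch: close first so no new task can be
    // added, then drain and shut down whatever is left.
    owned.close();
    while let Some(task) = owned.pop_back() {
        task.shutdown();
    }

    // A spawn that races with shutdown observes the closed bit and gets no
    // Notified back, so nothing can reach the queues after they are emptied.
    let (_join, notified) = owned.bind(2);
    assert!(notified.is_none());
}
```

Because `close()` happens before the drain, a concurrent `bind` either lands in the list and is drained, or observes the closed bit and is cancelled immediately; that is the invariant the shutdown documentation in worker.rs relies on.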
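On the ref-count change in state.rs: a task now starts with three references (the OwnedTasks or LocalOwnedTasks entry, the Notified handed to the scheduler, and the JoinHandle), so `transition_to_running` no longer needs the lazy `ref_inc` that used to accompany binding on first poll. The toy below only illustrates the counting argument with a bare `AtomicUsize`; the real state word also packs lifecycle bits alongside the count behind `REF_COUNT_SHIFT`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Toy stand-in for the packed task state: only the reference count is
// modeled here.
struct RawTask {
    refs: AtomicUsize,
}

impl RawTask {
    // Mirrors the new INITIAL_STATE: three references exist from creation,
    // one per handle (owned-list entry, Notified, JoinHandle).
    fn new() -> Self {
        RawTask { refs: AtomicUsize::new(3) }
    }

    // Returns true when the last reference goes away and the task memory
    // could be released.
    fn ref_dec(&self) -> bool {
        self.refs.fetch_sub(1, Ordering::AcqRel) == 1
    }
}

fn main() {
    let task = RawTask::new();

    assert!(!task.ref_dec()); // Notified consumed by the scheduler: 3 -> 2
    assert!(!task.ref_dec()); // JoinHandle dropped by the user:     2 -> 1
    assert!(task.ref_dec());  // entry removed from the owned list:  1 -> 0
}
```

This is also why `unowned` (used for blocking tasks and tests) simply drops the list reference right after `new_task`, leaving the two references exercised by the `create_drop1`/`create_drop2` tests in tests/task.rs.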
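Finally, `LocalOwnedTasks` stays usable for `!Send` futures by carrying a `PhantomData<*const ()>` marker and taking `&mut self` instead of holding a `Mutex`, while the thread-safe `OwnedTasks` keeps its lock. A minimal illustration of that marker technique follows; `LocalOnly` and `assert_send` are made-up names for the demonstration, not part of tokio.

```rust
use std::marker::PhantomData;

// A raw pointer is neither Send nor Sync, so a PhantomData<*const ()> field
// strips those auto traits from the containing type without storing any data.
struct LocalOnly {
    value: u32,
    _not_send: PhantomData<*const ()>,
}

#[allow(dead_code)]
fn assert_send<T: Send>(_: &T) {}

fn main() {
    let local = LocalOnly { value: 7, _not_send: PhantomData };
    println!("value = {}", local.value);

    // Uncommenting the next line fails to compile, because `*const ()` makes
    // `LocalOnly` !Send, the same trick LocalOwnedTasks uses in this patch:
    // assert_send(&local);
}
```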