diff --git a/tokio/src/loom/std/atomic_u64.rs b/tokio/src/loom/std/atomic_u64.rs
index 7eb457a2405..8ea6bd403a4 100644
--- a/tokio/src/loom/std/atomic_u64.rs
+++ b/tokio/src/loom/std/atomic_u64.rs
@@ -2,19 +2,15 @@
 //! re-export of `AtomicU64`. On 32 bit platforms, this is implemented using a
 //! `Mutex`.
 
-pub(crate) use self::imp::AtomicU64;
-
 // `AtomicU64` can only be used on targets where `target_has_atomic` is 64 or greater.
 // Once the `cfg_target_has_atomic` feature is stable, we can replace it with
 // `#[cfg(target_has_atomic = "64")]`.
 // Refs: https://github.com/rust-lang/rust/tree/master/src/librustc_target
-#[cfg(not(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc")))]
-mod imp {
+cfg_has_atomic_u64! {
     pub(crate) use std::sync::atomic::AtomicU64;
 }
 
-#[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))]
-mod imp {
+cfg_not_has_atomic_u64! {
     use crate::loom::sync::Mutex;
     use std::sync::atomic::Ordering;
 
diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs
index 1e77556d8db..dcd2d795acf 100644
--- a/tokio/src/macros/cfg.rs
+++ b/tokio/src/macros/cfg.rs
@@ -384,3 +384,29 @@ macro_rules! cfg_not_coop {
         )*
     }
 }
+
+macro_rules! cfg_has_atomic_u64 {
+    ($($item:item)*) => {
+        $(
+            #[cfg(not(any(
+                target_arch = "arm",
+                target_arch = "mips",
+                target_arch = "powerpc"
+            )))]
+            $item
+        )*
+    }
+}
+
+macro_rules! cfg_not_has_atomic_u64 {
+    ($($item:item)*) => {
+        $(
+            #[cfg(any(
+                target_arch = "arm",
+                target_arch = "mips",
+                target_arch = "powerpc"
+            ))]
+            $item
+        )*
+    }
+}
diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs
index 09a3d7dd718..0c0888c9e6a 100644
--- a/tokio/src/runtime/basic_scheduler.rs
+++ b/tokio/src/runtime/basic_scheduler.rs
@@ -246,7 +246,10 @@ impl<P: Park> Inner<P> {
                     };
 
                     match entry {
-                        RemoteMsg::Schedule(task) => crate::coop::budget(|| task.run()),
+                        RemoteMsg::Schedule(task) => {
+                            let task = context.shared.owned.assert_owner(task);
+                            crate::coop::budget(|| task.run())
+                        }
                     }
                 }
 
@@ -319,29 +322,25 @@ impl<P: Park> Drop for BasicScheduler<P> {
             }
 
             // Drain local queue
+            // We already shut down every task, so we just need to drop the tasks.
             for task in context.tasks.borrow_mut().queue.drain(..) {
-                task.shutdown();
+                drop(task);
             }
 
             // Drain remote queue and set it to None
-            let mut remote_queue = scheduler.spawner.shared.queue.lock();
+            let remote_queue = scheduler.spawner.shared.queue.lock().take();
 
             // Using `Option::take` to replace the shared queue with `None`.
-            if let Some(remote_queue) = remote_queue.take() {
+            // We already shut down every task, so we just need to drop the tasks.
+            if let Some(remote_queue) = remote_queue {
                 for entry in remote_queue {
                     match entry {
                         RemoteMsg::Schedule(task) => {
-                            task.shutdown();
+                            drop(task);
                         }
                     }
                 }
             }
-            // By dropping the mutex lock after the full duration of the above loop,
-            // any thread that sees the queue in the `None` state is guaranteed that
-            // the runtime has fully shut down.
-            //
-            // The assert below is unrelated to this mutex.
-            drop(remote_queue);
 
             assert!(context.shared.owned.is_empty());
         });
@@ -400,8 +399,7 @@ impl fmt::Debug for Spawner {
 
 impl Schedule for Arc<Shared> {
     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
-        // SAFETY: Inserted into the list in bind above.
-        unsafe { self.owned.remove(task) }
+        self.owned.remove(task)
     }
 
     fn schedule(&self, task: task::Notified<Self>) {
diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs
index b7d725128d7..0c23bb0dc86 100644
--- a/tokio/src/runtime/blocking/pool.rs
+++ b/tokio/src/runtime/blocking/pool.rs
@@ -71,7 +71,7 @@ struct Shared {
     worker_thread_index: usize,
 }
 
-type Task = task::Notified<NoopSchedule>;
+type Task = task::UnownedTask<NoopSchedule>;
 
 const KEEP_ALIVE: Duration = Duration::from_secs(10);
 
diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs
index e8e4fef7cff..31668027330 100644
--- a/tokio/src/runtime/task/core.rs
+++ b/tokio/src/runtime/task/core.rs
@@ -65,6 +65,19 @@ pub(crate) struct Header {
     /// Table of function pointers for executing actions on the task.
     pub(super) vtable: &'static Vtable,
 
+    /// This integer contains the id of the OwnedTasks or LocalOwnedTasks
+    /// that this task is stored in. If the task is not in any list, this is
+    /// the id of the list it was previously in, or zero if it has never been
+    /// in any list.
+    ///
+    /// Once a task has been bound to a list, it can never be bound to another
+    /// list, even if removed from the first list.
+    ///
+    /// The id is not unset when removed from a list because we want to be
+    /// able to read the id without synchronization, even if the task is
+    /// concurrently being removed from the list.
+    pub(super) owner_id: UnsafeCell<u64>,
+
     /// The tracing ID for this instrumented task.
     #[cfg(all(tokio_unstable, feature = "tracing"))]
     pub(super) id: Option<tracing::Id>,
@@ -98,6 +111,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
                 owned: UnsafeCell::new(linked_list::Pointers::new()),
                 queue_next: UnsafeCell::new(None),
                 vtable: raw::vtable::<T, S>(),
+                owner_id: UnsafeCell::new(0),
                 #[cfg(all(tokio_unstable, feature = "tracing"))]
                 id,
             },
@@ -203,12 +217,27 @@ impl<T: Future> CoreStage<T> {
 
 cfg_rt_multi_thread! {
     impl Header {
-        pub(crate) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
+        pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
             self.queue_next.with_mut(|ptr| *ptr = next);
         }
     }
 }
 
+impl Header {
+    // safety: The caller must guarantee exclusive access to this field, and
+    // must ensure that the id is either 0 or the id of the OwnedTasks
+    // containing this task.
+    pub(super) unsafe fn set_owner_id(&self, owner: u64) {
+        self.owner_id.with_mut(|ptr| *ptr = owner);
+    }
+
+    pub(super) fn get_owner_id(&self) -> u64 {
+        // safety: If there are concurrent writes, then that write has violated
+        // the safety requirements on `set_owner_id`.
+        unsafe { self.owner_id.with(|ptr| *ptr) }
+    }
+}
+
 impl Trailer {
     pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) {
         self.waker.with_mut(|ptr| {
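The `owner_id` field above is deliberately a plain `UnsafeCell<u64>` rather than an atomic: it is written exactly once, while the creator still has exclusive access, so every later unsynchronized read is sound. A standalone sketch of that write-once pattern, not part of the patch (names are illustrative):

    use std::cell::UnsafeCell;

    struct Header {
        owner_id: UnsafeCell<u64>,
    }

    // The id is only written before the header is shared between threads,
    // so the unsynchronized reads afterwards cannot race with the write.
    unsafe impl Sync for Header {}

    impl Header {
        fn new() -> Self {
            Header { owner_id: UnsafeCell::new(0) }
        }

        // safety: the caller must have exclusive access to the header.
        unsafe fn set_owner_id(&self, owner: u64) {
            *self.owner_id.get() = owner;
        }

        fn get_owner_id(&self) -> u64 {
            // safety: a concurrent write would violate set_owner_id's contract.
            unsafe { *self.owner_id.get() }
        }
    }

    fn main() {
        let header = Header::new();
        // Exclusive access right after creation, before sharing:
        unsafe { header.set_owner_id(7) };
        assert_eq!(header.get_owner_id(), 7);
    }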
diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs
index 63812f9c76f..3be8ed3b6b2 100644
--- a/tokio/src/runtime/task/list.rs
+++ b/tokio/src/runtime/task/list.rs
@@ -8,13 +8,53 @@
 
 use crate::future::Future;
 use crate::loom::sync::Mutex;
-use crate::runtime::task::{JoinHandle, Notified, Schedule, Task};
+use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
 use crate::util::linked_list::{Link, LinkedList};
 
 use std::marker::PhantomData;
 
+// The id from the module below is used to verify whether a given task is stored
+// in this OwnedTasks, or some other list. The counter starts at one so we can
+// use zero for tasks not owned by any list.
+//
+// The safety checks in this file can technically be violated if the counter is
+// overflown, but the checks are not supposed to ever fail unless there is a
+// bug in Tokio, so we accept that certain bugs would not be caught if the two
+// mixed up runtimes happen to have the same id.
+
+cfg_has_atomic_u64! {
+    use std::sync::atomic::{AtomicU64, Ordering};
+
+    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);
+
+    fn get_next_id() -> u64 {
+        loop {
+            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
+            if id != 0 {
+                return id;
+            }
+        }
+    }
+}
+
+cfg_not_has_atomic_u64! {
+    use std::sync::atomic::{AtomicU32, Ordering};
+
+    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);
+
+    fn get_next_id() -> u64 {
+        loop {
+            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
+            if id != 0 {
+                return u64::from(id);
+            }
+        }
+    }
+}
+
 pub(crate) struct OwnedTasks<S: 'static> {
     inner: Mutex<OwnedTasksInner<S>>,
+    id: u64,
 }
 
 struct OwnedTasksInner<S: 'static> {
     list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
@@ -24,7 +64,8 @@ struct OwnedTasksInner<S: 'static> {
 pub(crate) struct LocalOwnedTasks<S: 'static> {
     list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
     closed: bool,
-    _not_send: PhantomData<*const ()>,
+    id: u64,
+    _not_send_or_sync: PhantomData<*const ()>,
 }
 
@@ -34,6 +75,7 @@ impl<S: 'static> OwnedTasks<S> {
                 list: LinkedList::new(),
                 closed: false,
             }),
+            id: get_next_id(),
         }
     }
 
@@ -51,11 +93,17 @@ impl<S: 'static> OwnedTasks<S> {
     {
         let (task, notified, join) = super::new_task(task, scheduler);
 
+        unsafe {
+            // safety: We just created the task, so we have exclusive access
+            // to the field.
+            task.header().set_owner_id(self.id);
+        }
+
         let mut lock = self.inner.lock();
         if lock.closed {
             drop(lock);
-            drop(task);
-            notified.shutdown();
+            drop(notified);
+            task.shutdown();
             (join, None)
         } else {
             lock.list.push_front(task);
@@ -63,14 +111,36 @@ impl<S: 'static> OwnedTasks<S> {
         }
     }
 
+    /// Assert that the given task is owned by this OwnedTasks and convert it to
+    /// a LocalNotified, giving the thread permission to poll this task.
+    #[inline]
+    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
+        assert_eq!(task.0.header().get_owner_id(), self.id);
+
+        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
+        // to poll it on this thread no matter what thread we are on.
+        LocalNotified {
+            task: task.0,
+            _not_send: PhantomData,
+        }
+    }
+
     pub(crate) fn pop_back(&self) -> Option<Task<S>> {
         self.inner.lock().list.pop_back()
     }
 
-    /// The caller must ensure that if the provided task is stored in a
-    /// linked list, then it is in this linked list.
-    pub(crate) unsafe fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
-        self.inner.lock().list.remove(task.header().into())
+    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
+        let task_id = task.header().get_owner_id();
+        if task_id == 0 {
+            // The task is unowned.
+            return None;
+        }
+
+        assert_eq!(task_id, self.id);
+
+        // safety: We just checked that the provided task is not in some other
+        // linked list.
+        unsafe { self.inner.lock().list.remove(task.header().into()) }
     }
 
     pub(crate) fn is_empty(&self) -> bool {
@@ -93,7 +163,8 @@ impl<S: 'static> LocalOwnedTasks<S> {
         Self {
             list: LinkedList::new(),
             closed: false,
-            _not_send: PhantomData,
+            id: get_next_id(),
+            _not_send_or_sync: PhantomData,
         }
     }
 
@@ -109,9 +180,15 @@ impl<S: 'static> LocalOwnedTasks<S> {
     {
         let (task, notified, join) = super::new_task(task, scheduler);
 
+        unsafe {
+            // safety: We just created the task, so we have exclusive access
+            // to the field.
+            task.header().set_owner_id(self.id);
+        }
+
         if self.closed {
-            drop(task);
-            notified.shutdown();
+            drop(notified);
+            task.shutdown();
             (join, None)
         } else {
             self.list.push_front(task);
@@ -123,10 +200,33 @@ impl<S: 'static> LocalOwnedTasks<S> {
         self.list.pop_back()
     }
 
-    /// The caller must ensure that if the provided task is stored in a
-    /// linked list, then it is in this linked list.
-    pub(crate) unsafe fn remove(&mut self, task: &Task<S>) -> Option<Task<S>> {
-        self.list.remove(task.header().into())
+    pub(crate) fn remove(&mut self, task: &Task<S>) -> Option<Task<S>> {
+        let task_id = task.header().get_owner_id();
+        if task_id == 0 {
+            // The task is unowned.
+            return None;
+        }
+
+        assert_eq!(task_id, self.id);
+
+        // safety: We just checked that the provided task is not in some other
+        // linked list.
+        unsafe { self.list.remove(task.header().into()) }
+    }
+
+    /// Assert that the given task is owned by this LocalOwnedTasks and convert
+    /// it to a LocalNotified, giving the thread permission to poll this task.
+    #[inline]
+    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
+        assert_eq!(task.0.header().get_owner_id(), self.id);
+
+        // safety: The task was bound to this LocalOwnedTasks, and the
+        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
+        // for polling this task.
+        LocalNotified {
+            task: task.0,
+            _not_send: PhantomData,
+        }
     }
 
     pub(crate) fn is_empty(&self) -> bool {
@@ -139,3 +239,23 @@ impl<S: 'static> LocalOwnedTasks<S> {
         self.closed = true;
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // This test may run in parallel with other tests, so we only test that
+    // ids come in increasing order.
+    #[test]
+    fn test_id_not_broken() {
+        let mut last_id = get_next_id();
+        assert_ne!(last_id, 0);
+
+        for _ in 0..1000 {
+            let next_id = get_next_id();
+            assert_ne!(next_id, 0);
+            assert!(last_id < next_id);
+            last_id = next_id;
+        }
+    }
+}
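Taken together, the unique nonzero list id and the per-task owner id replace the old unsafe caller contract ("this task is in this list") with a cheap runtime assertion. A toy model of the mechanism, not part of the patch (all names are illustrative):

    use std::sync::atomic::{AtomicU64, Ordering};

    // Zero is reserved for "never bound to any list"; the real allocator
    // additionally skips zero if the counter ever wraps.
    static NEXT_LIST_ID: AtomicU64 = AtomicU64::new(1);

    struct OwnedList {
        id: u64,
    }

    struct Item {
        owner_id: u64,
    }

    impl OwnedList {
        fn new() -> Self {
            OwnedList { id: NEXT_LIST_ID.fetch_add(1, Ordering::Relaxed) }
        }

        // Binding happens exactly once, at creation time.
        fn bind(&self) -> Item {
            Item { owner_id: self.id }
        }

        // Replaces an unsafe "caller must guarantee" contract with a runtime
        // check: an item bound to one list can never pass for another list's.
        fn assert_owner(&self, item: &Item) {
            assert_eq!(item.owner_id, self.id);
        }
    }

    fn main() {
        let list = OwnedList::new();
        let item = list.bind();
        list.assert_owner(&item); // passes

        let other = OwnedList::new();
        let _ = &other;
        // other.assert_owner(&item) would panic: wrong owner id.
    }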
diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs
index 1ca5ba579b5..adc91d81d3f 100644
--- a/tokio/src/runtime/task/mod.rs
+++ b/tokio/src/runtime/task/mod.rs
@@ -1,6 +1,6 @@
 mod core;
 use self::core::Cell;
-pub(crate) use self::core::Header;
+use self::core::Header;
 
 mod error;
 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
@@ -46,13 +46,34 @@ pub(crate) struct Task<S: 'static> {
 unsafe impl<S> Send for Task<S> {}
 unsafe impl<S> Sync for Task<S> {}
 
-/// A task was notified
+/// A task was notified.
 #[repr(transparent)]
 pub(crate) struct Notified<S: 'static>(Task<S>);
 
+// safety: This type cannot be used to touch the task without first verifying
+// that the value is on a thread where it is safe to poll the task.
 unsafe impl<S: Schedule> Send for Notified<S> {}
 unsafe impl<S: Schedule> Sync for Notified<S> {}
 
+/// A non-Send variant of Notified with the invariant that it is on a thread
+/// where it is safe to poll it.
+#[repr(transparent)]
+pub(crate) struct LocalNotified<S: 'static> {
+    task: Task<S>,
+    _not_send: PhantomData<*const ()>,
+}
+
+/// A task that is not owned by any OwnedTasks. Used for blocking tasks.
+/// This type holds two ref-counts.
+pub(crate) struct UnownedTask<S: 'static> {
+    raw: RawTask,
+    _p: PhantomData<S>,
+}
+
+// safety: This type can only be created given a Send task.
+unsafe impl<S> Send for UnownedTask<S> {}
+unsafe impl<S> Sync for UnownedTask<S> {}
+
 /// Task result sent back
 pub(crate) type Result<T> = std::result::Result<T, JoinError>;
 
@@ -105,41 +126,50 @@ cfg_rt! {
     /// Create a new task with an associated join handle. This method is used
     /// only when the task is not going to be stored in an `OwnedTasks` list.
     ///
-    /// Currently only blocking tasks and tests use this method.
-    pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (Notified<S>, JoinHandle<T::Output>)
+    /// Currently only blocking tasks use this method.
+    pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>)
     where
         S: Schedule,
         T: Send + Future + 'static,
         T::Output: Send + 'static,
     {
         let (task, notified, join) = new_task(task, scheduler);
-        drop(task);
-        (notified, join)
+
+        // This transfers the ref-count of task and notified into an UnownedTask.
+        // This is valid because an UnownedTask holds two ref-counts.
+        let unowned = UnownedTask {
+            raw: task.raw,
+            _p: PhantomData,
+        };
+        std::mem::forget(task);
+        std::mem::forget(notified);
+
+        (unowned, join)
     }
 }
 
 impl<S: 'static> Task<S> {
-    pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> {
+    unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> {
         Task {
             raw: RawTask::from_raw(ptr),
             _p: PhantomData,
         }
     }
 
-    pub(crate) fn header(&self) -> &Header {
+    fn header(&self) -> &Header {
         self.raw.header()
     }
 }
 
 cfg_rt_multi_thread! {
     impl<S: 'static> Notified<S> {
-        pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> {
+        unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> {
             Notified(Task::from_raw(ptr))
         }
     }
 
     impl<S: 'static> Task<S> {
-        pub(crate) fn into_raw(self) -> NonNull<Header> {
+        fn into_raw(self) -> NonNull<Header> {
             let ret = self.header().into();
             mem::forget(self);
             ret
@@ -147,7 +177,7 @@ cfg_rt_multi_thread! {
     }
 
     impl<S: 'static> Notified<S> {
-        pub(crate) fn into_raw(self) -> NonNull<Header> {
+        fn into_raw(self) -> NonNull<Header> {
             self.0.into_raw()
         }
     }
@@ -160,16 +190,45 @@ impl<S: Schedule> Task<S> {
     }
 }
 
-impl<S: Schedule> Notified<S> {
+impl<S: Schedule> LocalNotified<S> {
     /// Run the task
     pub(crate) fn run(self) {
-        self.0.raw.poll();
+        self.task.raw.poll();
+        mem::forget(self);
+    }
+}
+
+impl<S: Schedule> UnownedTask<S> {
+    // Used in tests of the inject queue.
+    #[cfg(test)]
+    pub(super) fn into_notified(self) -> Notified<S> {
+        Notified(self.into_task())
+    }
+
+    fn into_task(self) -> Task<S> {
+        // Convert into a task.
+        let task = Task {
+            raw: self.raw,
+            _p: PhantomData,
+        };
+        mem::forget(self);
+
+        // Drop a ref-count since an UnownedTask holds two.
+        task.header().state.ref_dec();
+
+        task
+    }
+
+    pub(crate) fn run(self) {
+        // Decrement the ref-count
+        self.raw.header().state.ref_dec();
+        // Poll the task
+        self.raw.poll();
         mem::forget(self);
     }
 
-    /// Pre-emptively cancel the task as part of the shutdown process.
     pub(crate) fn shutdown(self) {
-        self.0.shutdown();
+        self.into_task().shutdown()
     }
 }
 
@@ -183,6 +242,16 @@ impl<S: 'static> Drop for Task<S> {
     }
 }
 
+impl<S: 'static> Drop for UnownedTask<S> {
+    fn drop(&mut self) {
+        // Decrement the ref count
+        if self.raw.header().state.ref_dec_twice() {
+            // Deallocate if this is the final ref count
+            self.raw.dealloc();
+        }
+    }
+}
+
 impl<S> fmt::Debug for Task<S> {
     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(fmt, "Task({:p})", self.header())
diff --git a/tokio/src/runtime/task/state.rs b/tokio/src/runtime/task/state.rs
index 2580e9c83d2..3641af083b9 100644
--- a/tokio/src/runtime/task/state.rs
+++ b/tokio/src/runtime/task/state.rs
@@ -324,6 +324,12 @@ impl State {
         prev.ref_count() == 1
     }
 
+    /// Returns `true` if the task should be released.
+    pub(super) fn ref_dec_twice(&self) -> bool {
+        let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, AcqRel));
+        prev.ref_count() == 2
+    }
+
     fn fetch_update<F>(&self, mut f: F) -> Result<Snapshot, Snapshot>
     where
         F: FnMut(Snapshot) -> Option<Snapshot>,
diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs
index 48919fa716e..be36d6ffe4d 100644
--- a/tokio/src/runtime/tests/mod.rs
+++ b/tokio/src/runtime/tests/mod.rs
@@ -14,7 +14,7 @@ mod unowned_wrapper {
         let span = tracing::trace_span!("test_span");
         let task = task.instrument(span);
         let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule);
-        (task, handle)
+        (task.into_notified(), handle)
     }
 
     #[cfg(not(all(tokio_unstable, feature = "tracing")))]
@@ -24,7 +24,7 @@ mod unowned_wrapper {
         T::Output: Send + 'static,
     {
         let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule);
-        (task, handle)
+        (task.into_notified(), handle)
     }
 }
diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs
index ae5c04c840c..39cb1d2204c 100644
--- a/tokio/src/runtime/tests/task.rs
+++ b/tokio/src/runtime/tests/task.rs
@@ -241,6 +241,7 @@ impl Runtime {
         while !self.is_empty() && n < max {
             let task = self.next_task();
             n += 1;
+            let task = self.0.owned.assert_owner(task);
             task.run();
         }
 
@@ -264,7 +265,7 @@ impl Runtime {
         }
 
         while let Some(task) = core.queue.pop_back() {
-            task.shutdown();
+            drop(task);
         }
 
         drop(core);
@@ -275,8 +276,7 @@ impl Runtime {
 
 impl Schedule for Runtime {
     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
-        // safety: copying worker.rs
-        unsafe { self.0.owned.remove(task) }
+        self.0.owned.remove(task)
     }
 
     fn schedule(&self, task: task::Notified<Self>) {
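`LocalNotified`, which the call sites above and below receive from `assert_owner`, encodes "the owner check passed on this thread" in the type system: its raw-pointer `PhantomData` suppresses the auto-derived `Send` and `Sync` impls, so the value cannot leave the thread that performed the check. A minimal sketch of the wrapper pattern, not part of the patch (names are illustrative):

    use std::marker::PhantomData;

    // Stand-in for the real Notified task handle.
    struct Notified(usize);

    // Proof that the owner check passed on the current thread. The raw-pointer
    // PhantomData removes the auto-derived Send and Sync impls.
    struct LocalNotified {
        task: Notified,
        _not_send: PhantomData<*const ()>,
    }

    fn assert_owner(task: Notified) -> LocalNotified {
        // The real implementation compares the task's owner id with the
        // list's id here and panics on a mismatch.
        LocalNotified { task, _not_send: PhantomData }
    }

    impl LocalNotified {
        // Polling may only happen on the thread that proved ownership.
        fn run(self) {
            let Notified(raw) = self.task;
            let _ = raw;
        }
    }

    fn main() {
        let local = assert_owner(Notified(0));
        local.run();
        // std::thread::spawn(move || local.run()) would not compile:
        // LocalNotified is neither Send nor Sync.
    }

Because the wrapper is `!Send`, handing it to something like `std::thread::spawn` fails to compile, which is exactly the guarantee the schedulers need before polling.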
diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs
index c0535238ced..608a7353ae9 100644
--- a/tokio/src/runtime/thread_pool/worker.rs
+++ b/tokio/src/runtime/thread_pool/worker.rs
@@ -384,6 +384,8 @@ impl Context {
     }
 
     fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult {
+        let task = self.worker.shared.owned.assert_owner(task);
+
         // Make sure the worker is not in the **searching** state. This enables
         // another idle worker to try to steal work.
         core.transition_from_searching(&self.worker);
@@ -414,6 +416,7 @@ impl Context {
                     if coop::has_budget_remaining() {
                         // Run the LIFO task, then loop
                         *self.core.borrow_mut() = Some(core);
+                        let task = self.worker.shared.owned.assert_owner(task);
                         task.run();
                     } else {
                         // Not enough budget left to run the LIFO task, push it to
@@ -626,8 +629,7 @@ impl Worker {
 
 impl task::Schedule for Arc<Worker> {
     fn release(&self, task: &Task) -> Option<Task> {
-        // SAFETY: Inserted into owned in bind.
-        unsafe { self.owned.remove(task) }
+        self.owned.remove(task)
     }
 
     fn schedule(&self, task: Notified) {
@@ -762,8 +764,10 @@ impl Shared {
         }
 
         // Drain the injection queue
+        //
+        // We already shut down every task, so we can simply drop the tasks.
         while let Some(task) = self.inject.pop() {
-            task.shutdown();
+            drop(task);
        }
    }
 
diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs
index f05e0ebafd4..98d03f62cca 100644
--- a/tokio/src/task/local.rs
+++ b/tokio/src/task/local.rs
@@ -540,11 +540,11 @@ impl LocalSet {
         true
     }
 
-    fn next_task(&self) -> Option<task::Notified<Arc<Shared>>> {
+    fn next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>> {
         let tick = self.tick.get();
         self.tick.set(tick.wrapping_add(1));
 
-        if tick % REMOTE_FIRST_INTERVAL == 0 {
+        let task = if tick % REMOTE_FIRST_INTERVAL == 0 {
             self.context
                 .shared
                 .queue
@@ -566,7 +566,9 @@ impl LocalSet {
                         .as_mut()
                         .and_then(|queue| queue.pop_front())
                 })
-        }
+        };
+
+        task.map(|task| self.context.tasks.borrow_mut().owned.assert_owner(task))
     }
 
     fn with<T>(&self, f: impl FnOnce() -> T) -> T {
@@ -631,15 +633,17 @@ impl Drop for LocalSet {
                 task.shutdown();
             }
 
+            // We already called shutdown on all tasks above, so there is no
+            // need to call shutdown again; just drop the tasks.
             for task in self.context.tasks.borrow_mut().queue.drain(..) {
-                task.shutdown();
+                drop(task);
             }
 
             // Take the queue from the Shared object to prevent pushing
             // notifications to it in the future.
             let queue = self.context.shared.queue.lock().take().unwrap();
             for task in queue {
-                task.shutdown();
+                drop(task);
             }
 
             assert!(self.context.tasks.borrow().owned.is_empty());
@@ -711,12 +715,8 @@ impl<T: Future> Future for RunUntil<'_, T> {
 impl task::Schedule for Arc<Shared> {
     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
         CURRENT.with(|maybe_cx| {
             let cx = maybe_cx.expect("scheduler context missing");
-
             assert!(cx.shared.ptr_eq(self));
-
-            // safety: task must be contained by list. It is inserted into the
-            // list when spawning.
-            unsafe { cx.tasks.borrow_mut().owned.remove(&task) }
+            cx.tasks.borrow_mut().owned.remove(&task)
         })
     }
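One note on the ref-count arithmetic from state.rs: an `UnownedTask` owns two references, so its `Drop` releases both with a single `fetch_sub`, and the task must be deallocated exactly when those two were the last ones. `run()` instead releases one count up front and forgets `self`, leaving the final count to the poll machinery. A toy model of the accounting, not part of the patch (illustrative names, plain counts instead of the packed state word):

    use std::sync::atomic::{AtomicUsize, Ordering::AcqRel};

    struct State {
        refs: AtomicUsize,
    }

    impl State {
        // Returns true if this was the last reference.
        fn ref_dec(&self) -> bool {
            self.refs.fetch_sub(1, AcqRel) == 1
        }

        // Gives up two references at once; true if they were the last two.
        fn ref_dec_twice(&self) -> bool {
            self.refs.fetch_sub(2, AcqRel) == 2
        }
    }

    fn main() {
        // Dropping an unused UnownedTask: both counts released at once.
        let dropped = State { refs: AtomicUsize::new(2) };
        assert!(dropped.ref_dec_twice());

        // Running it: one count released up front, the last by the poll.
        let ran = State { refs: AtomicUsize::new(2) };
        assert!(!ran.ref_dec());
        assert!(ran.ref_dec());
    }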