Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add owner id for tasks in OwnedTasks #3979

Merged
merged 11 commits into from Jul 27, 2021
8 changes: 2 additions & 6 deletions tokio/src/loom/std/atomic_u64.rs
Expand Up @@ -2,19 +2,15 @@
//! re-export of `AtomicU64`. On 32 bit platforms, this is implemented using a
//! `Mutex`.

pub(crate) use self::imp::AtomicU64;

// `AtomicU64` can only be used on targets with `target_has_atomic` is 64 or greater.
// Once `cfg_target_has_atomic` feature is stable, we can replace it with
// `#[cfg(target_has_atomic = "64")]`.
// Refs: https://github.com/rust-lang/rust/tree/master/src/librustc_target
#[cfg(not(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc")))]
mod imp {
cfg_has_atomic_u64! {
pub(crate) use std::sync::atomic::AtomicU64;
}

#[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))]
mod imp {
cfg_not_has_atomic_u64! {
use crate::loom::sync::Mutex;
use std::sync::atomic::Ordering;

Expand Down
26 changes: 26 additions & 0 deletions tokio/src/macros/cfg.rs
Expand Up @@ -384,3 +384,29 @@ macro_rules! cfg_not_coop {
)*
}
}

// Expands the wrapped items only on targets that have a native 64-bit
// atomic (i.e. everything except the listed 32-bit architectures).
macro_rules! cfg_has_atomic_u64 {
    ($($item:item)*) => {
        $(
            #[cfg(not(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc")))]
            $item
        )*
    };
}

// Inverse of `cfg_has_atomic_u64`: expands the wrapped items only on the
// listed 32-bit architectures, which lack a native 64-bit atomic.
macro_rules! cfg_not_has_atomic_u64 {
    ($($item:item)*) => {
        $(
            #[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))]
            $item
        )*
    };
}
24 changes: 11 additions & 13 deletions tokio/src/runtime/basic_scheduler.rs
Expand Up @@ -246,7 +246,10 @@ impl<P: Park> Inner<P> {
};

match entry {
RemoteMsg::Schedule(task) => crate::coop::budget(|| task.run()),
RemoteMsg::Schedule(task) => {
let task = context.shared.owned.assert_owner(task);
crate::coop::budget(|| task.run())
}
}
}

Expand Down Expand Up @@ -319,29 +322,25 @@ impl<P: Park> Drop for BasicScheduler<P> {
}

// Drain local queue
// We already shut down every task, so we just need to drop the task.
for task in context.tasks.borrow_mut().queue.drain(..) {
task.shutdown();
drop(task);
}

// Drain remote queue and set it to None
let mut remote_queue = scheduler.spawner.shared.queue.lock();
let remote_queue = scheduler.spawner.shared.queue.lock().take();

// Using `Option::take` to replace the shared queue with `None`.
if let Some(remote_queue) = remote_queue.take() {
// We already shut down every task, so we just need to drop the task.
if let Some(remote_queue) = remote_queue {
for entry in remote_queue {
match entry {
RemoteMsg::Schedule(task) => {
task.shutdown();
drop(task);
}
}
}
}
// By dropping the mutex lock after the full duration of the above loop,
// any thread that sees the queue in the `None` state is guaranteed that
// the runtime has fully shut down.
//
// The assert below is unrelated to this mutex.
drop(remote_queue);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason to remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was never necessary to hold the mutex locked for this long in the first place. It was introduced in #3752.


assert!(context.shared.owned.is_empty());
});
Expand Down Expand Up @@ -400,8 +399,7 @@ impl fmt::Debug for Spawner {

impl Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
// SAFETY: Inserted into the list in bind above.
unsafe { self.owned.remove(task) }
self.owned.remove(task)
}

fn schedule(&self, task: task::Notified<Self>) {
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/runtime/blocking/pool.rs
Expand Up @@ -71,7 +71,7 @@ struct Shared {
worker_thread_index: usize,
}

type Task = task::Notified<NoopSchedule>;
type Task = task::UnownedTask<NoopSchedule>;

const KEEP_ALIVE: Duration = Duration::from_secs(10);

Expand Down
31 changes: 30 additions & 1 deletion tokio/src/runtime/task/core.rs
Expand Up @@ -69,6 +69,19 @@ pub(crate) struct Header {
/// Table of function pointers for executing actions on the task.
pub(super) vtable: &'static Vtable,

/// This integer contains the id of the OwnedTasks or LocalOwnedTasks that
/// this task is stored in. If the task is not in any list, then this field
/// holds the id of the list it was previously in, or zero if it has never
/// been in any list.
///
/// Once a task has been bound to a list, it can never be bound to another
/// list, even if removed from the first list.
///
/// The id is not unset when removed from a list because we want to be able
/// to read the id without synchronization, even if it is concurrently being
/// removed from the list.
pub(super) owner_id: UnsafeCell<u64>,

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) id: Option<tracing::Id>,
Expand Down Expand Up @@ -102,6 +115,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
owned: UnsafeCell::new(linked_list::Pointers::new()),
queue_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
owner_id: UnsafeCell::new(0),
#[cfg(all(tokio_unstable, feature = "tracing"))]
id,
},
Expand Down Expand Up @@ -267,12 +281,27 @@ impl<T: Future> CoreStage<T> {

cfg_rt_multi_thread! {
impl Header {
pub(crate) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
self.queue_next.with_mut(|ptr| *ptr = next);
}
}
}

impl Header {
    /// Stores the id of the task list that owns this task.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to this field, and the value
    /// written must be either zero or the id of the `OwnedTasks` (or
    /// `LocalOwnedTasks`) containing this task.
    pub(super) unsafe fn set_owner_id(&self, owner: u64) {
        self.owner_id.with_mut(|id| *id = owner);
    }

    /// Reads the owning list's id without synchronization.
    pub(super) fn get_owner_id(&self) -> u64 {
        // safety: Any concurrent write would already have broken the
        // exclusive-access contract of `set_owner_id`.
        unsafe { self.owner_id.with(|id| *id) }
    }
}

impl Trailer {
pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) {
self.waker.with_mut(|ptr| {
Expand Down
146 changes: 133 additions & 13 deletions tokio/src/runtime/task/list.rs
Expand Up @@ -8,13 +8,53 @@

use crate::future::Future;
use crate::loom::sync::Mutex;
use crate::runtime::task::{JoinHandle, Notified, Schedule, Task};
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
use crate::util::linked_list::{Link, LinkedList};

use std::marker::PhantomData;

// The id from the module below is used to verify whether a given task is stored
// in this OwnedTasks, or some other task. The counter starts at one so we can
// use zero for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are never supposed to fail unless there is a bug
// in Tokio, so we accept that certain bugs would not be caught if two
// mixed-up runtimes happen to have the same id.

cfg_has_atomic_u64! {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Id generator used on targets with a native 64-bit atomic. The counter
    // starts at one because zero is reserved for "never bound to any list".
    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    fn get_next_id() -> u64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if id == 0 {
                // The counter wrapped around; zero is reserved, so skip it.
                continue;
            }
            return id;
        }
    }
}

cfg_not_has_atomic_u64! {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Fallback id generator for targets without a native 64-bit atomic; the
    // 32-bit counter value is widened to u64 before being returned.
    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    fn get_next_id() -> u64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if id == 0 {
                // Skip zero: it marks tasks that were never bound to a list.
                continue;
            }
            return u64::from(id);
        }
    }
}

/// A thread-safe collection of tasks owned by a runtime, keyed by a unique
/// list id so ownership of a task can be verified cheaply.
pub(crate) struct OwnedTasks<S: 'static> {
    /// The list of tasks, behind a mutex so any thread may insert or remove.
    inner: Mutex<OwnedTasksInner<S>>,
    /// The id of this list, obtained from `get_next_id`; compared against a
    /// task's stored owner id to assert that the task belongs here.
    id: u64,
}
struct OwnedTasksInner<S: 'static> {
list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
Expand All @@ -24,6 +64,7 @@ struct OwnedTasksInner<S: 'static> {
/// A single-threaded counterpart of `OwnedTasks`: no mutex, and `!Send` so
/// it can only be touched from the thread that created it.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    /// The list of tasks owned by this collection.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    /// Set once the collection is closed; new tasks bound afterwards are
    /// immediately shut down instead of being inserted.
    closed: bool,
    /// The id of this list, obtained from `get_next_id`.
    id: u64,
    /// Marker making this type `!Send`.
    _not_send: PhantomData<*const ()>,
}

Expand All @@ -34,6 +75,7 @@ impl<S: 'static> OwnedTasks<S> {
list: LinkedList::new(),
closed: false,
}),
id: get_next_id(),
}
}

Expand All @@ -51,26 +93,54 @@ impl<S: 'static> OwnedTasks<S> {
{
let (task, notified, join) = super::new_task(task, scheduler);

unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
task.header().set_owner_id(self.id);
}

let mut lock = self.inner.lock();
if lock.closed {
drop(lock);
drop(task);
notified.shutdown();
drop(notified);
task.shutdown();
(join, None)
} else {
lock.list.push_front(task);
(join, Some(notified))
}
}

/// Checks that `task` was bound to this `OwnedTasks` instance and, if so,
/// converts it into a `LocalNotified` that grants the current thread
/// permission to poll the task.
///
/// Panics if the task is owned by a different list.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
    let owner = task.0.header().get_owner_id();
    assert_eq!(owner, self.id);

    // safety: Every task bound to an `OwnedTasks` is `Send`, so polling it
    // from the current thread is sound regardless of which thread that is.
    LocalNotified {
        task: task.0,
        _not_send: PhantomData,
    }
}

/// Removes and returns a task from the back of the list, or `None` if the
/// list is empty.
pub(crate) fn pop_back(&self) -> Option<Task<S>> {
    self.inner.lock().list.pop_back()
}

/// The caller must ensure that if the provided task is stored in a
/// linked list, then it is in this linked list.
pub(crate) unsafe fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
self.inner.lock().list.remove(task.header().into())
/// Removes the given task from this list if it is stored here, returning
/// the removed task handle.
///
/// Returns `None` when the task is unowned (owner id zero). Panics if the
/// task is owned by a *different* list, since that indicates a bug.
pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
    let task_id = task.header().get_owner_id();
    if task_id == 0 {
        // The task is unowned.
        return None;
    }

    assert_eq!(task_id, self.id);

    // safety: We just checked that the provided task is not in some other
    // linked list.
    unsafe { self.inner.lock().list.remove(task.header().into()) }
}

pub(crate) fn is_empty(&self) -> bool {
Expand All @@ -93,6 +163,7 @@ impl<S: 'static> LocalOwnedTasks<S> {
Self {
list: LinkedList::new(),
closed: false,
id: get_next_id(),
_not_send: PhantomData,
}
}
Expand All @@ -109,9 +180,15 @@ impl<S: 'static> LocalOwnedTasks<S> {
{
let (task, notified, join) = super::new_task(task, scheduler);

unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
task.header().set_owner_id(self.id);
}

if self.closed {
drop(task);
notified.shutdown();
drop(notified);
task.shutdown();
(join, None)
} else {
self.list.push_front(task);
Expand All @@ -123,10 +200,33 @@ impl<S: 'static> LocalOwnedTasks<S> {
self.list.pop_back()
}

/// The caller must ensure that if the provided task is stored in a
/// linked list, then it is in this linked list.
pub(crate) unsafe fn remove(&mut self, task: &Task<S>) -> Option<Task<S>> {
self.list.remove(task.header().into())
/// Removes the given task from this list if it is stored here, returning
/// the removed task handle.
///
/// Returns `None` when the task is unowned (owner id zero). Panics if the
/// task is owned by a *different* list, since that indicates a bug.
pub(crate) fn remove(&mut self, task: &Task<S>) -> Option<Task<S>> {
    let task_id = task.header().get_owner_id();
    if task_id == 0 {
        // The task is unowned.
        return None;
    }

    assert_eq!(task_id, self.id);

    // safety: We just checked that the provided task is not in some other
    // linked list.
    unsafe { self.list.remove(task.header().into()) }
}

/// Assert that the given task is owned by this LocalOwnedTasks and convert
/// it to a LocalNotified, giving the thread permission to poll this task.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(task.0.header().get_owner_id(), self.id);

// safety: The task was bound to this LocalOwnedTasks, and the
// LocalOwnedTasks is not Send, so we are on the right thread for
// polling this task.
LocalNotified {
task: task.0,
_not_send: PhantomData,
}
}

pub(crate) fn is_empty(&self) -> bool {
Expand All @@ -139,3 +239,23 @@ impl<S: 'static> LocalOwnedTasks<S> {
self.closed = true;
}
}

// Fix: `#[cfg(all(test))]` is a redundant single-predicate `all()`; the
// idiomatic form is plain `#[cfg(test)]`.
#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests that also call
    // `get_next_id`, so we can only assert that ids are nonzero and strictly
    // increasing, not that they are consecutive.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();
        assert_ne!(last_id, 0);

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert_ne!(next_id, 0);
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
}