Skip to content

Commit

Permalink
runtime: add owner id for tasks in OwnedTasks (#3979)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darksonn committed Jul 27, 2021
1 parent 0de0542 commit f2a06bf
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 69 deletions.
8 changes: 2 additions & 6 deletions tokio/src/loom/std/atomic_u64.rs
Expand Up @@ -2,19 +2,15 @@
//! re-export of `AtomicU64`. On 32 bit platforms, this is implemented using a
//! `Mutex`.

pub(crate) use self::imp::AtomicU64;

// `AtomicU64` can only be used on targets where `target_has_atomic` is 64 or greater.
// Once `cfg_target_has_atomic` feature is stable, we can replace it with
// `#[cfg(target_has_atomic = "64")]`.
// Refs: https://github.com/rust-lang/rust/tree/master/src/librustc_target
#[cfg(not(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc")))]
mod imp {
cfg_has_atomic_u64! {
pub(crate) use std::sync::atomic::AtomicU64;
}

#[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))]
mod imp {
cfg_not_has_atomic_u64! {
use crate::loom::sync::Mutex;
use std::sync::atomic::Ordering;

Expand Down
26 changes: 26 additions & 0 deletions tokio/src/macros/cfg.rs
Expand Up @@ -384,3 +384,29 @@ macro_rules! cfg_not_coop {
)*
}
}

// Emits the wrapped items only on targets that provide a native `AtomicU64`
// (i.e. everything except the 32-bit architectures listed in the `cfg`).
macro_rules! cfg_has_atomic_u64 {
    ($($i:item)*) => {
        $(
            #[cfg(not(any(
                target_arch = "arm",
                target_arch = "mips",
                target_arch = "powerpc"
            )))]
            $i
        )*
    }
}

// Emits the wrapped items only on the 32-bit architectures that lack a native
// `AtomicU64` (the complement of `cfg_has_atomic_u64!`).
macro_rules! cfg_not_has_atomic_u64 {
    ($($i:item)*) => {
        $(
            #[cfg(any(
                target_arch = "arm",
                target_arch = "mips",
                target_arch = "powerpc"
            ))]
            $i
        )*
    }
}
24 changes: 11 additions & 13 deletions tokio/src/runtime/basic_scheduler.rs
Expand Up @@ -246,7 +246,10 @@ impl<P: Park> Inner<P> {
};

match entry {
RemoteMsg::Schedule(task) => crate::coop::budget(|| task.run()),
RemoteMsg::Schedule(task) => {
let task = context.shared.owned.assert_owner(task);
crate::coop::budget(|| task.run())
}
}
}

Expand Down Expand Up @@ -319,29 +322,25 @@ impl<P: Park> Drop for BasicScheduler<P> {
}

// Drain local queue
// We already shut down every task, so we just need to drop the task.
for task in context.tasks.borrow_mut().queue.drain(..) {
task.shutdown();
drop(task);
}

// Drain remote queue and set it to None
let mut remote_queue = scheduler.spawner.shared.queue.lock();
let remote_queue = scheduler.spawner.shared.queue.lock().take();

// Using `Option::take` to replace the shared queue with `None`.
if let Some(remote_queue) = remote_queue.take() {
// We already shut down every task, so we just need to drop the task.
if let Some(remote_queue) = remote_queue {
for entry in remote_queue {
match entry {
RemoteMsg::Schedule(task) => {
task.shutdown();
drop(task);
}
}
}
}
// By dropping the mutex lock after the full duration of the above loop,
// any thread that sees the queue in the `None` state is guaranteed that
// the runtime has fully shut down.
//
// The assert below is unrelated to this mutex.
drop(remote_queue);

assert!(context.shared.owned.is_empty());
});
Expand Down Expand Up @@ -400,8 +399,7 @@ impl fmt::Debug for Spawner {

impl Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
// SAFETY: Inserted into the list in bind above.
unsafe { self.owned.remove(task) }
self.owned.remove(task)
}

fn schedule(&self, task: task::Notified<Self>) {
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/runtime/blocking/pool.rs
Expand Up @@ -71,7 +71,7 @@ struct Shared {
worker_thread_index: usize,
}

type Task = task::Notified<NoopSchedule>;
type Task = task::UnownedTask<NoopSchedule>;

const KEEP_ALIVE: Duration = Duration::from_secs(10);

Expand Down
31 changes: 30 additions & 1 deletion tokio/src/runtime/task/core.rs
Expand Up @@ -65,6 +65,19 @@ pub(crate) struct Header {
/// Table of function pointers for executing actions on the task.
pub(super) vtable: &'static Vtable,

/// This integer contains the id of the OwnedTasks or LocalOwnedTasks that
/// this task is stored in. If the task is not in any list, should be the
/// id of the list that it was previously in, or zero if it has never been
/// in any list.
///
/// Once a task has been bound to a list, it can never be bound to another
/// list, even if removed from the first list.
///
/// The id is not unset when removed from a list because we want to be able
/// to read the id without synchronization, even if it is concurrently being
/// removed from the list.
pub(super) owner_id: UnsafeCell<u64>,

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) id: Option<tracing::Id>,
Expand Down Expand Up @@ -98,6 +111,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
owned: UnsafeCell::new(linked_list::Pointers::new()),
queue_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
owner_id: UnsafeCell::new(0),
#[cfg(all(tokio_unstable, feature = "tracing"))]
id,
},
Expand Down Expand Up @@ -203,12 +217,27 @@ impl<T: Future> CoreStage<T> {

cfg_rt_multi_thread! {
impl Header {
pub(crate) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
self.queue_next.with_mut(|ptr| *ptr = next);
}
}
}

impl Header {
    /// Stores the id of the `OwnedTasks`/`LocalOwnedTasks` list this task is
    /// bound to.
    ///
    /// # Safety
    ///
    /// The caller must guarantee exclusive access to this field, and
    /// must ensure that the id is either 0 or the id of the OwnedTasks
    /// containing this task.
    pub(super) unsafe fn set_owner_id(&self, owner: u64) {
        self.owner_id.with_mut(|ptr| *ptr = owner);
    }

    /// Reads the owner-list id without synchronization. Returns 0 when the
    /// task has never been bound to a list.
    pub(super) fn get_owner_id(&self) -> u64 {
        // safety: If there are concurrent writes, then that write has violated
        // the safety requirements on `set_owner_id`.
        unsafe { self.owner_id.with(|ptr| *ptr) }
    }
}

impl Trailer {
pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) {
self.waker.with_mut(|ptr| {
Expand Down
150 changes: 135 additions & 15 deletions tokio/src/runtime/task/list.rs
Expand Up @@ -8,13 +8,53 @@

use crate::future::Future;
use crate::loom::sync::Mutex;
use crate::runtime::task::{JoinHandle, Notified, Schedule, Task};
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
use crate::util::linked_list::{Link, LinkedList};

use std::marker::PhantomData;

// The id from the module below is used to verify whether a given task is stored
// in this OwnedTasks, or some other task. The counter starts at one so we can
// use zero for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter is
// overflowed, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if the two
// mixed up runtimes happen to have the same id.

cfg_has_atomic_u64! {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Global counter handing out list ids. Starts at 1 because the id 0 is
    // reserved for "this task has never been bound to any list".
    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    /// Returns a fresh non-zero id for an OwnedTasks/LocalOwnedTasks list.
    fn get_next_id() -> u64 {
        let mut id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
        // 0 is only observable if the counter wraps; skip it so 0 keeps its
        // "unowned" meaning.
        while id == 0 {
            id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
        }
        id
    }
}

cfg_not_has_atomic_u64! {
    use std::sync::atomic::{AtomicU32, Ordering};

    // 32-bit fallback counter for targets without a native `AtomicU64`.
    // Starts at 1 because the id 0 means "never bound to any list".
    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    /// Returns a fresh non-zero id, widened to `u64` to match the header field.
    fn get_next_id() -> u64 {
        let mut id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
        // Skip 0 (only reachable after wrap-around) so it keeps meaning
        // "unowned".
        while id == 0 {
            id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
        }
        u64::from(id)
    }
}

/// A thread-safe collection of tasks bound to one runtime. Each inserted
/// task's header is branded with `id`, which lets `assert_owner`/`remove`
/// verify membership without relying on the caller.
pub(crate) struct OwnedTasks<S: 'static> {
    // Lock-protected linked list holding every task bound to this collection.
    inner: Mutex<OwnedTasksInner<S>>,
    // Unique non-zero id of this list, assigned from a global counter
    // (`get_next_id`); written into each task's header on bind.
    id: u64,
}
struct OwnedTasksInner<S: 'static> {
list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
Expand All @@ -24,7 +64,8 @@ struct OwnedTasksInner<S: 'static> {
/// A single-threaded counterpart of `OwnedTasks`: no lock, and the type is
/// neither `Send` nor `Sync`, so all access happens from one thread.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    // Linked list of every task bound to this collection.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    // Once true (set by `close`), newly bound tasks are shut down immediately.
    closed: bool,
    // Unique non-zero id of this list; written into each task's header on bind.
    id: u64,
    // Raw-pointer marker that makes the type !Send and !Sync.
    _not_send_or_sync: PhantomData<*const ()>,
}

impl<S: 'static> OwnedTasks<S> {
Expand All @@ -34,6 +75,7 @@ impl<S: 'static> OwnedTasks<S> {
list: LinkedList::new(),
closed: false,
}),
id: get_next_id(),
}
}

Expand All @@ -51,26 +93,54 @@ impl<S: 'static> OwnedTasks<S> {
{
let (task, notified, join) = super::new_task(task, scheduler);

unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
task.header().set_owner_id(self.id);
}

let mut lock = self.inner.lock();
if lock.closed {
drop(lock);
drop(task);
notified.shutdown();
drop(notified);
task.shutdown();
(join, None)
} else {
lock.list.push_front(task);
(join, Some(notified))
}
}

/// Assert that the given task is owned by this OwnedTasks and convert it to
/// a LocalNotified, giving the thread permission to poll this task.
///
/// # Panics
///
/// Panics if the task's owner id does not match this list's id, i.e. the
/// task was spawned on a different runtime.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
    assert_eq!(task.0.header().get_owner_id(), self.id);

    // safety: All tasks bound to this OwnedTasks are Send, so it is safe
    // to poll it on this thread no matter what thread we are on.
    LocalNotified {
        task: task.0,
        _not_send: PhantomData,
    }
}

/// Removes and returns the task at the back of the list, if any.
pub(crate) fn pop_back(&self) -> Option<Task<S>> {
    let mut inner = self.inner.lock();
    inner.list.pop_back()
}

/// The caller must ensure that if the provided task is stored in a
/// linked list, then it is in this linked list.
pub(crate) unsafe fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
self.inner.lock().list.remove(task.header().into())
/// Removes the task from this list if it is bound to this list, returning
/// the removed task. Returns `None` for tasks that were never bound to any
/// list (owner id 0).
///
/// # Panics
///
/// Panics if the task is bound to a *different* OwnedTasks.
pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
    let task_id = task.header().get_owner_id();
    if task_id == 0 {
        // The task is unowned.
        return None;
    }

    assert_eq!(task_id, self.id);

    // safety: We just checked that the provided task is not in some other
    // linked list.
    unsafe { self.inner.lock().list.remove(task.header().into()) }
}

pub(crate) fn is_empty(&self) -> bool {
Expand All @@ -93,7 +163,8 @@ impl<S: 'static> LocalOwnedTasks<S> {
Self {
list: LinkedList::new(),
closed: false,
_not_send: PhantomData,
id: get_next_id(),
_not_send_or_sync: PhantomData,
}
}

Expand All @@ -109,9 +180,15 @@ impl<S: 'static> LocalOwnedTasks<S> {
{
let (task, notified, join) = super::new_task(task, scheduler);

unsafe {
// safety: We just created the task, so we have exclusive access
// to the field.
task.header().set_owner_id(self.id);
}

if self.closed {
drop(task);
notified.shutdown();
drop(notified);
task.shutdown();
(join, None)
} else {
self.list.push_front(task);
Expand All @@ -123,10 +200,33 @@ impl<S: 'static> LocalOwnedTasks<S> {
self.list.pop_back()
}

/// The caller must ensure that if the provided task is stored in a
/// linked list, then it is in this linked list.
pub(crate) unsafe fn remove(&mut self, task: &Task<S>) -> Option<Task<S>> {
self.list.remove(task.header().into())
/// Removes the task from this list if it is bound to this list, returning
/// the removed task. Returns `None` for tasks that were never bound to any
/// list (owner id 0).
///
/// # Panics
///
/// Panics if the task is bound to a *different* list.
pub(crate) fn remove(&mut self, task: &Task<S>) -> Option<Task<S>> {
    let task_id = task.header().get_owner_id();
    if task_id == 0 {
        // The task is unowned.
        return None;
    }

    assert_eq!(task_id, self.id);

    // safety: We just checked that the provided task is not in some other
    // linked list.
    unsafe { self.list.remove(task.header().into()) }
}

/// Assert that the given task is owned by this LocalOwnedTasks and convert
/// it to a LocalNotified, giving the thread permission to poll this task.
///
/// # Panics
///
/// Panics if the task's owner id does not match this list's id.
#[inline]
pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
    assert_eq!(task.0.header().get_owner_id(), self.id);

    // safety: The task was bound to this LocalOwnedTasks, and the
    // LocalOwnedTasks is not Send or Sync, so we are on the right thread
    // for polling this task.
    LocalNotified {
        task: task.0,
        _not_send: PhantomData,
    }
}

pub(crate) fn is_empty(&self) -> bool {
Expand All @@ -139,3 +239,23 @@ impl<S: 'static> LocalOwnedTasks<S> {
self.closed = true;
}
}

// `#[cfg(all(test))]` with a single predicate is redundant — plain
// `#[cfg(test)]` is the idiomatic form (clippy::non_minimal_cfg).
#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests, so we only test that ids
    // come in increasing order.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();
        assert_ne!(last_id, 0);

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert_ne!(next_id, 0);
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
}

0 comments on commit f2a06bf

Please sign in to comment.