Skip to content

Commit

Permalink
runtime: reduce codegen per task (#5213)
Browse files Browse the repository at this point in the history
This PR should hopefully reduce the amount of code generated per
future-type spawned on the runtime. The following methods are no longer generic:

* `try_set_join_waker`
* `remote_abort`
* `clone_waker`
* `drop_waker`
* `wake_by_ref`
* `wake_by_val`

A new method is added to the vtable called schedule, which is used when a task
should be scheduled on the runtime. E.g. wake_by_ref will call it if the state change
says that the task needs to be scheduled. However, this method is only generic over
the scheduler, and not the future type, so it also isn't generated for every task.

Additionally, one of the changes involved in the above makes it possible to remove
the id field from JoinHandle and AbortHandle.
  • Loading branch information
Darksonn committed Nov 21, 2022
1 parent 304b515 commit 45e37db
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 257 deletions.
35 changes: 14 additions & 21 deletions tokio/src/runtime/task/abort.rs
@@ -1,4 +1,4 @@
use crate::runtime::task::{Id, RawTask};
use crate::runtime::task::{Header, RawTask};
use std::fmt;
use std::panic::{RefUnwindSafe, UnwindSafe};

Expand All @@ -14,13 +14,12 @@ use std::panic::{RefUnwindSafe, UnwindSafe};
/// [`JoinHandle`]: crate::task::JoinHandle
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct AbortHandle {
raw: Option<RawTask>,
id: Id,
raw: RawTask,
}

impl AbortHandle {
pub(super) fn new(raw: Option<RawTask>, id: Id) -> Self {
Self { raw, id }
pub(super) fn new(raw: RawTask) -> Self {
Self { raw }
}

/// Abort the task associated with the handle.
Expand All @@ -35,9 +34,7 @@ impl AbortHandle {
/// [cancelled]: method@super::error::JoinError::is_cancelled
/// [`JoinHandle::abort`]: method@super::JoinHandle::abort
pub fn abort(&self) {
if let Some(ref raw) = self.raw {
raw.remote_abort();
}
self.raw.remote_abort();
}

/// Checks if the task associated with this `AbortHandle` has finished.
Expand All @@ -47,12 +44,8 @@ impl AbortHandle {
/// some time, and this method does not return `true` until it has
/// completed.
pub fn is_finished(&self) -> bool {
if let Some(raw) = self.raw {
let state = raw.header().state.load();
state.is_complete()
} else {
true
}
let state = self.raw.state().load();
state.is_complete()
}

/// Returns a [task ID] that uniquely identifies this task relative to other
Expand All @@ -67,7 +60,8 @@ impl AbortHandle {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
self.id
// Safety: The header pointer is valid.
unsafe { Header::get_id(self.raw.header_ptr()) }
}
}

Expand All @@ -79,16 +73,15 @@ impl RefUnwindSafe for AbortHandle {}

impl fmt::Debug for AbortHandle {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("AbortHandle")
.field("id", &self.id)
.finish()
// Safety: The header pointer is valid.
let id_ptr = unsafe { Header::get_id_ptr(self.raw.header_ptr()) };
let id = unsafe { id_ptr.as_ref() };
fmt.debug_struct("AbortHandle").field("id", id).finish()
}
}

impl Drop for AbortHandle {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
raw.drop_abort_handle();
}
self.raw.drop_abort_handle();
}
}
74 changes: 67 additions & 7 deletions tokio/src/runtime/task/core.rs
Expand Up @@ -25,6 +25,9 @@ use std::task::{Context, Poll, Waker};
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both *mut Cell and *mut Header.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
/// Hot task state data
Expand All @@ -44,15 +47,19 @@ pub(super) struct CoreStage<T: Future> {
/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
/// Scheduler used to drive this future.
pub(super) scheduler: S,

/// Either the future or the output.
pub(super) stage: CoreStage<T>,

/// The task's ID, used for populating `JoinError`s.
pub(super) task_id: Id,

/// Either the future or the output.
pub(super) stage: CoreStage<T>,
}

/// Crate public as this is also needed by the pool.
Expand Down Expand Up @@ -82,7 +89,7 @@ pub(crate) struct Header {

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) id: Option<tracing::Id>,
pub(super) tracing_id: Option<tracing::Id>,
}

unsafe impl Send for Header {}
Expand Down Expand Up @@ -117,15 +124,15 @@ impl<T: Future, S: Schedule> Cell<T, S> {
/// structures.
pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let id = future.id();
let tracing_id = future.id();
let result = Box::new(Cell {
header: Header {
state,
queue_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
owner_id: UnsafeCell::new(0),
#[cfg(all(tokio_unstable, feature = "tracing"))]
id,
tracing_id,
},
core: Core {
scheduler,
Expand All @@ -144,8 +151,16 @@ impl<T: Future, S: Schedule> Cell<T, S> {
{
let trailer_addr = (&result.trailer) as *const Trailer as usize;
let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) };

assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

let scheduler_addr = (&result.core.scheduler) as *const S as usize;
let scheduler_ptr =
unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) };
assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

let id_addr = (&result.core.task_id) as *const Id as usize;
let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) };
assert_eq!(id_addr, id_ptr.as_ptr() as usize);
}

result
Expand Down Expand Up @@ -295,6 +310,51 @@ impl Header {
let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
NonNull::new_unchecked(trailer)
}

/// Gets a pointer to the scheduler of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
///
/// The generic type S must be set to the correct scheduler type for this
/// task.
pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
let offset = me.as_ref().vtable.scheduler_offset;
let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
NonNull::new_unchecked(scheduler)
}

/// Gets a pointer to the id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
let offset = me.as_ref().vtable.id_offset;
let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
NonNull::new_unchecked(id)
}

/// Gets the id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
let ptr = Header::get_id_ptr(me).as_ptr();
*ptr
}

/// Gets the tracing id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
me.as_ref().tracing_id.as_ref()
}
}

impl Trailer {
Expand Down

0 comments on commit 45e37db

Please sign in to comment.