Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime: reduce codegen per task #5213

Merged
merged 2 commits into from Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 14 additions & 21 deletions tokio/src/runtime/task/abort.rs
@@ -1,4 +1,4 @@
use crate::runtime::task::{Id, RawTask};
use crate::runtime::task::{Header, RawTask};
use std::fmt;
use std::panic::{RefUnwindSafe, UnwindSafe};

Expand All @@ -14,13 +14,12 @@ use std::panic::{RefUnwindSafe, UnwindSafe};
/// [`JoinHandle`]: crate::task::JoinHandle
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct AbortHandle {
raw: Option<RawTask>,
id: Id,
raw: RawTask,
}

impl AbortHandle {
pub(super) fn new(raw: Option<RawTask>, id: Id) -> Self {
Self { raw, id }
pub(super) fn new(raw: RawTask) -> Self {
Self { raw }
}

/// Abort the task associated with the handle.
Expand All @@ -35,9 +34,7 @@ impl AbortHandle {
/// [cancelled]: method@super::error::JoinError::is_cancelled
/// [`JoinHandle::abort`]: method@super::JoinHandle::abort
pub fn abort(&self) {
if let Some(ref raw) = self.raw {
raw.remote_abort();
}
self.raw.remote_abort();
}

/// Checks if the task associated with this `AbortHandle` has finished.
Expand All @@ -47,12 +44,8 @@ impl AbortHandle {
/// some time, and this method does not return `true` until it has
/// completed.
pub fn is_finished(&self) -> bool {
if let Some(raw) = self.raw {
let state = raw.header().state.load();
state.is_complete()
} else {
true
}
let state = self.raw.state().load();
state.is_complete()
}

/// Returns a [task ID] that uniquely identifies this task relative to other
Expand All @@ -67,7 +60,8 @@ impl AbortHandle {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
self.id
// Safety: The header pointer is valid.
unsafe { Header::get_id(self.raw.header_ptr()) }
}
}

Expand All @@ -79,16 +73,15 @@ impl RefUnwindSafe for AbortHandle {}

impl fmt::Debug for AbortHandle {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("AbortHandle")
.field("id", &self.id)
.finish()
// Safety: The header pointer is valid.
let id_ptr = unsafe { Header::get_id_ptr(self.raw.header_ptr()) };
let id = unsafe { id_ptr.as_ref() };
fmt.debug_struct("AbortHandle").field("id", id).finish()
}
}

impl Drop for AbortHandle {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
raw.drop_abort_handle();
}
self.raw.drop_abort_handle();
}
}
74 changes: 67 additions & 7 deletions tokio/src/runtime/task/core.rs
Expand Up @@ -25,6 +25,9 @@ use std::task::{Context, Poll, Waker};
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both *mut Cell and *mut Header.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
/// Hot task state data
Expand All @@ -44,15 +47,19 @@ pub(super) struct CoreStage<T: Future> {
/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
/// Scheduler used to drive this future.
pub(super) scheduler: S,

/// Either the future or the output.
pub(super) stage: CoreStage<T>,

/// The task's ID, used for populating `JoinError`s.
pub(super) task_id: Id,

/// Either the future or the output.
pub(super) stage: CoreStage<T>,
}

/// Crate public as this is also needed by the pool.
Expand Down Expand Up @@ -82,7 +89,7 @@ pub(crate) struct Header {

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) id: Option<tracing::Id>,
pub(super) tracing_id: Option<tracing::Id>,
}

unsafe impl Send for Header {}
Expand Down Expand Up @@ -117,15 +124,15 @@ impl<T: Future, S: Schedule> Cell<T, S> {
/// structures.
pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let id = future.id();
let tracing_id = future.id();
let result = Box::new(Cell {
header: Header {
state,
queue_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
owner_id: UnsafeCell::new(0),
#[cfg(all(tokio_unstable, feature = "tracing"))]
id,
tracing_id,
},
core: Core {
scheduler,
Expand All @@ -144,8 +151,16 @@ impl<T: Future, S: Schedule> Cell<T, S> {
{
let trailer_addr = (&result.trailer) as *const Trailer as usize;
let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) };

assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

let scheduler_addr = (&result.core.scheduler) as *const S as usize;
let scheduler_ptr =
unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) };
assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

let id_addr = (&result.core.task_id) as *const Id as usize;
let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) };
assert_eq!(id_addr, id_ptr.as_ptr() as usize);
}

result
Expand Down Expand Up @@ -295,6 +310,51 @@ impl Header {
let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
NonNull::new_unchecked(trailer)
}

/// Gets a pointer to the scheduler of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
///
/// The generic type S must be set to the correct scheduler type for this
/// task.
pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
let offset = me.as_ref().vtable.scheduler_offset;
let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
NonNull::new_unchecked(scheduler)
}

/// Gets a pointer to the id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
let offset = me.as_ref().vtable.id_offset;
let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
NonNull::new_unchecked(id)
}

/// Gets the id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
let ptr = Header::get_id_ptr(me).as_ptr();
*ptr
}

/// Gets the tracing id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
me.as_ref().tracing_id.as_ref()
}
}

impl Trailer {
Expand Down