Skip to content

Commit

Permalink
WIP: implement cancellation and raw handles
Browse files Browse the repository at this point in the history
  • Loading branch information
udoprog committed May 2, 2020
1 parent 20b5df9 commit 4ef9002
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 44 deletions.
99 changes: 59 additions & 40 deletions tokio/src/runtime/task/harness.rs
Expand Up @@ -154,56 +154,75 @@ where

/// Read the task output into `dst`.
pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) {
if self.try_poll_completion(waker) {
*dst = Poll::Ready(self.core().take_output());
}
}

pub(super) fn try_read_completion(self, dst: &mut Poll<super::Result<()>>, waker: &Waker) {
if self.try_poll_completion(waker) {
// TODO: might be room for more optimization to avoid reading out
// the output (avoiding the map).
*dst = Poll::Ready(self.core().take_output().map(|_| ()));
}
}

/// Try to poll the task associated with the harness to completion, or
/// install the associated waker to be notified when the task is complete.
///
/// Returns a boolean indicating if the task is complete or not.
fn try_poll_completion(&self, waker: &Waker) -> bool {
// Load a snapshot of the current task state
let snapshot = self.header().state.load();

debug_assert!(snapshot.is_join_interested());

if !snapshot.is_complete() {
// The waker must be stored in the task struct.
let res = if snapshot.has_join_waker() {
// There already is a waker stored in the struct. If it matches
// the provided waker, then there is no further work to do.
// Otherwise, the waker must be swapped.
let will_wake = unsafe {
// Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE`
// may mutate the `waker` field.
self.trailer()
.waker
.with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
};

if will_wake {
// The task is not complete **and** the waker is up to date,
// there is nothing further that needs to be done.
return;
}
if snapshot.is_complete() {
return true;
}

// Unset the `JOIN_WAKER` to gain mutable access to the `waker`
// field then update the field with the new join worker.
//
// This requires two atomic operations, unsetting the bit and
// then resetting it. If the task transitions to complete
// concurrently to either one of those operations, then setting
// the join waker fails and we proceed to reading the task
// output.
self.header()
.state
.unset_waker()
.and_then(|snapshot| self.set_join_waker(waker.clone(), snapshot))
} else {
self.set_join_waker(waker.clone(), snapshot)
// The waker must be stored in the task struct.
let res = if snapshot.has_join_waker() {
// There already is a waker stored in the struct. If it matches
// the provided waker, then there is no further work to do.
// Otherwise, the waker must be swapped.
let will_wake = unsafe {
// Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE`
// may mutate the `waker` field.
self.trailer()
.waker
.with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
};

match res {
Ok(_) => return,
Err(snapshot) => {
assert!(snapshot.is_complete());
}
if will_wake {
// The task is not complete **and** the waker is up to date,
// there is nothing further that needs to be done.
return false;
}
}

*dst = Poll::Ready(self.core().take_output());
// Unset the `JOIN_WAKER` to gain mutable access to the `waker`
// field then update the field with the new join worker.
//
// This requires two atomic operations, unsetting the bit and
// then resetting it. If the task transitions to complete
// concurrently to either one of those operations, then setting
// the join waker fails and we proceed to reading the task
// output.
self.header()
.state
.unset_waker()
.and_then(|snapshot| self.set_join_waker(waker.clone(), snapshot))
} else {
self.set_join_waker(waker.clone(), snapshot)
};

match res {
Ok(_) => false,
Err(snapshot) => {
assert!(snapshot.is_complete());
true
}
}
}

fn set_join_waker(&self, waker: Waker, snapshot: Snapshot) -> Result<Snapshot, Snapshot> {
Expand Down
190 changes: 186 additions & 4 deletions tokio/src/runtime/task/join.rs
Expand Up @@ -76,9 +76,26 @@ doc_rt_core! {
/// [`task::spawn_blocking`]: crate::task::spawn_blocking
/// [`std::thread::JoinHandle`]: std::thread::JoinHandle
pub struct JoinHandle<T> {
raw: Option<RawTask>,
raw: RawJoinHandle,
_p: PhantomData<T>,
}

/// A type-erased variant of [`JoinHandle`], these are created by using
/// [`JoinHandle::into_raw_handle`].
///
/// Raw join handles erase the type information associated with a
/// [`JoinHandle`], allowing them to easily be stored in containers for
/// future cancellation.
///
/// They behave exactly the same as regular [`JoinHandle`], except that
/// instead of resolving to `Result<T, JoinError>`, they resolve to
/// `Result<(), JoinError>`.
///
/// [`JoinHandle`]: JoinHandle
/// [`JoinHandle::into_raw_handle`]: JoinHandle::into_raw_handle
pub struct RawJoinHandle {
raw: Option<RawTask>,
}
}

unsafe impl<T: Send> Send for JoinHandle<T> {}
Expand All @@ -87,13 +104,171 @@ unsafe impl<T: Send> Sync for JoinHandle<T> {}
impl<T> JoinHandle<T> {
pub(super) fn new(raw: RawTask) -> JoinHandle<T> {
JoinHandle {
raw: Some(raw),
raw: RawJoinHandle { raw: Some(raw) },
_p: PhantomData,
}
}

/// Cancel the task associated with the handle.
///
/// Awaiting a cancelled task might complete as usual if the task was
/// already completed at the time it was cancelled, but most likely it
/// will complete with a `Err(JoinError::Cancelled)`.
///
/// ```rust
/// use tokio::time;
///
/// #[tokio::main]
/// async fn main() {
/// let mut handles = Vec::new();
///
/// handles.push(tokio::spawn(async {
/// time::delay_for(time::Duration::from_secs(10)).await;
/// true
/// }));
///
/// handles.push(tokio::spawn(async {
/// time::delay_for(time::Duration::from_secs(10)).await;
/// false
/// }));
///
/// for handle in &handles {
/// handle.cancel();
/// }
///
/// for handle in handles {
/// assert!(handle.await.unwrap_err().is_cancelled());
/// }
/// }
/// ```
pub fn cancel(&self) {
if let Some(raw) = self.raw.raw {
raw.shutdown();
}
}

/// Convert the join handle into a raw join handle.
///
/// Raw join handles erase the type information associated with a
/// [`JoinHandle`], allowing them to easily be stored in containers for
/// future cancellation.
///
/// [`JoinHandle`]: JoinHandle
///
/// # Examples
///
/// ```rust
/// use tokio::time;
///
/// #[tokio::main]
/// async fn main() {
/// let mut handles = Vec::new();
///
/// let handle = tokio::spawn(async {
/// time::delay_for(time::Duration::from_secs(10)).await;
/// true
/// });
///
/// handles.push(handle.into_raw_handle());
///
/// let handle = tokio::spawn(async {
/// time::delay_for(time::Duration::from_secs(10)).await;
/// 1u32
/// });
///
/// handles.push(handle.into_raw_handle());
///
/// for handle in &handles {
/// handle.cancel();
/// }
///
/// for handle in handles {
/// assert!(handle.await.unwrap_err().is_cancelled());
/// }
/// }
/// ```
pub fn into_raw_handle(self) -> RawJoinHandle {
self.raw
}
}

impl RawJoinHandle {
/// Cancel the task associated with the handle.
///
/// Awaiting a cancelled task might complete as usual if the task was
/// already completed at the time it was cancelled, but most likely it
/// will complete with a `Err(JoinError::Cancelled)`.
///
/// # Examples
///
/// ```rust
/// use tokio::time;
///
/// #[tokio::main]
/// async fn main() {
/// let mut handles = Vec::new();
///
/// let handle = tokio::spawn(async {
/// time::delay_for(time::Duration::from_secs(10)).await;
/// true
/// });
///
/// handles.push(handle.into_raw_handle());
///
/// let handle = tokio::spawn(async {
/// time::delay_for(time::Duration::from_secs(10)).await;
/// 1u32
/// });
///
/// handles.push(handle.into_raw_handle());
///
/// for handle in &handles {
/// handle.cancel();
/// }
///
/// for handle in handles {
/// assert!(handle.await.unwrap_err().is_cancelled());
/// }
/// }
/// ```
pub fn cancel(&self) {
if let Some(raw) = self.raw {
raw.shutdown();
}
}
}

impl<T> Unpin for JoinHandle<T> {}
impl Unpin for RawJoinHandle {}

impl Future for RawJoinHandle {
type Output = super::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut ret = Poll::Pending;

// Keep track of task budget
ready!(crate::coop::poll_proceed(cx));

// Raw should always be set. If it is not, this is due to polling after
// completion
let raw = self
.raw
.as_ref()
.expect("polling after `JoinHandle` already completed");

// Try to read the task completion. If the task is not yet complete, the
// waker is stored and is notified once the task does complete.
//
// Safety:
//
// `ret` must be of type `Poll<()>`.
unsafe {
raw.try_read_completion(&mut ret as *mut _ as *mut (), cx.waker());
}

ret
}
}

impl<T> Future for JoinHandle<T> {
type Output = super::Result<T>;
Expand All @@ -107,6 +282,7 @@ impl<T> Future for JoinHandle<T> {
// Raw should always be set. If it is not, this is due to polling after
// completion
let raw = self
.raw
.raw
.as_ref()
.expect("polling after `JoinHandle` already completed");
Expand All @@ -130,7 +306,7 @@ impl<T> Future for JoinHandle<T> {
}
}

impl<T> Drop for JoinHandle<T> {
impl Drop for RawJoinHandle {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
if raw.header().state.drop_join_handle_fast().is_ok() {
Expand All @@ -150,3 +326,9 @@ where
fmt.debug_struct("JoinHandle").finish()
}
}

impl fmt::Debug for RawJoinHandle {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("RawJoinHandle").finish()
}
}
21 changes: 21 additions & 0 deletions tokio/src/runtime/task/raw.rs
Expand Up @@ -19,6 +19,9 @@ pub(super) struct Vtable {
/// Read the task output, if complete
pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),

/// Read the task completion, if complete
pub(super) try_read_completion: unsafe fn(NonNull<Header>, *mut (), &Waker),

/// The join handle has been dropped
pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),

Expand All @@ -32,6 +35,7 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
poll: poll::<T, S>,
dealloc: dealloc::<T, S>,
try_read_output: try_read_output::<T, S>,
try_read_completion: try_read_completion::<T, S>,
drop_join_handle_slow: drop_join_handle_slow::<T, S>,
shutdown: shutdown::<T, S>,
}
Expand Down Expand Up @@ -80,6 +84,12 @@ impl RawTask {
(vtable.try_read_output)(self.ptr, dst, waker);
}

/// Safety: `dst` must be a `*mut Poll<super::Result<()>>`.
pub(super) unsafe fn try_read_completion(self, dst: *mut (), waker: &Waker) {
let vtable = self.header().vtable;
(vtable.try_read_completion)(self.ptr, dst, waker);
}

pub(super) fn drop_join_handle_slow(self) {
let vtable = self.header().vtable;
unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
Expand Down Expand Up @@ -120,6 +130,17 @@ unsafe fn try_read_output<T: Future, S: Schedule>(
harness.try_read_output(out, waker);
}

unsafe fn try_read_completion<T: Future, S: Schedule>(
ptr: NonNull<Header>,
dst: *mut (),
waker: &Waker,
) {
let out = &mut *(dst as *mut Poll<super::Result<()>>);

let harness = Harness::<T, S>::from_raw(ptr);
harness.try_read_completion(out, waker);
}

unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.drop_join_handle_slow()
Expand Down

0 comments on commit 4ef9002

Please sign in to comment.