Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abortable streams #2410

Merged
merged 9 commits into from May 10, 2021
185 changes: 185 additions & 0 deletions futures-util/src/abortable.rs
@@ -0,0 +1,185 @@
use crate::task::AtomicWaker;
use alloc::sync::Arc;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_core::Stream;
use pin_project_lite::pin_project;

pin_project! {
/// A future/stream which can be remotely short-circuited using an `AbortHandle`.
#[derive(Debug, Clone)]
#[must_use = "futures/streams do nothing unless you poll them"]
pub struct Abortable<T> {
#[pin]
task: T,
inner: Arc<AbortInner>,
}
}

impl<T> Abortable<T> {
/// Creates a new `Abortable` future/stream using an existing `AbortRegistration`.
/// `AbortRegistration`s can be acquired through `AbortHandle::new`.
///
/// When `abort` is called on the handle tied to `reg` or if `abort` has
/// already been called, the future/stream will complete immediately without making
/// any further progress.
///
/// # Examples:
///
/// Usage with futures:
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future::{Abortable, AbortHandle, Aborted};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let future = Abortable::new(async { 2 }, abort_registration);
/// abort_handle.abort();
/// assert_eq!(future.await, Err(Aborted));
/// # });
/// ```
///
/// Usage with streams:
///
/// ```
/// # futures::executor::block_on(async {
/// # use futures::future::{Abortable, AbortHandle};
/// # use futures::stream::{self, StreamExt};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration);
/// abort_handle.abort();
/// assert_eq!(stream.next().await, None);
/// # });
/// ```
pub fn new(task: T, reg: AbortRegistration) -> Self {
Self { task, inner: reg.inner }
}

/// Checks whether the task has been aborted. Note that all this
/// method indicates is whether [`AbortHandle::abort`] was *called*.
/// This means that it will return `true` even if:
/// * `abort` was called after the task had completed.
/// * `abort` was called while the task was being polled - the task may still be running and
/// will not be stopped until `poll` returns.
pub fn is_aborted(&self) -> bool {
self.inner.aborted.load(Ordering::Relaxed)
}
}

/// A registration handle for an `Abortable` task.
/// Values of this type can be acquired from `AbortHandle::new` and are used
/// in calls to `Abortable::new`.
#[derive(Debug)]
pub struct AbortRegistration {
inner: Arc<AbortInner>,
}

/// A handle to an `Abortable` task.
#[derive(Debug, Clone)]
pub struct AbortHandle {
inner: Arc<AbortInner>,
}

impl AbortHandle {
/// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
/// to abort a running future or stream.
///
/// This function is usually paired with a call to [`Abortable::new`].
pub fn new_pair() -> (Self, AbortRegistration) {
let inner =
Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) });

(Self { inner: inner.clone() }, AbortRegistration { inner })
}
}

// Inner type storing the waker to awaken and a bool indicating that it
// should be aborted.
#[derive(Debug)]
struct AbortInner {
waker: AtomicWaker,
aborted: AtomicBool,
}

/// Indicator that the `Abortable` task was aborted.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Aborted;

impl fmt::Display for Aborted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "`Abortable` future has been aborted")
}
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}

impl<T> Abortable<T> {
fn try_poll<I>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>,
) -> Poll<Result<I, Aborted>> {
// Check if the task has been aborted
if self.is_aborted() {
return Poll::Ready(Err(Aborted));
}

// attempt to complete the task
if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) {
return Poll::Ready(Ok(x));
}

// Register to receive a wakeup if the task is aborted in the future
self.inner.waker.register(cx.waker());

// Check to see if the task was aborted between the first check and
// registration.
// Checking with `is_aborted` which uses `Relaxed` is sufficient because
// `register` introduces an `AcqRel` barrier.
if self.is_aborted() {
return Poll::Ready(Err(Aborted));
}

Poll::Pending
}
}

impl<Fut> Future for Abortable<Fut>
where
Fut: Future,
{
type Output = Result<Fut::Output, Aborted>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.try_poll(cx, |fut, cx| fut.poll(cx))
}
}

impl<St> Stream for Abortable<St>
where
St: Stream,
{
type Item = St::Item;
ibraheemdev marked this conversation as resolved.
Show resolved Hide resolved

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.try_poll(cx, |stream, cx| stream.poll_next(cx)).map(Result::ok).map(Option::flatten)
}
}

impl AbortHandle {
/// Abort the `Abortable` stream/future associated with this handle.
///
/// Notifies the Abortable task associated with this handle that it
/// should abort. Note that if the task is currently being polled on
/// another thread, it will not immediately stop running. Instead, it will
/// continue to run until its poll method returns.
pub fn abort(&self) {
self.inner.aborted.store(true, Ordering::Relaxed);
self.inner.waker.wake();
}
}
158 changes: 4 additions & 154 deletions futures-util/src/future/abortable.rs
@@ -1,101 +1,8 @@
use super::assert_future;
use crate::task::AtomicWaker;
use alloc::sync::Arc;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use crate::future::{AbortHandle, Abortable, Aborted};
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use pin_project_lite::pin_project;

pin_project! {
/// A future which can be remotely short-circuited using an `AbortHandle`.
#[derive(Debug, Clone)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Abortable<Fut> {
#[pin]
future: Fut,
inner: Arc<AbortInner>,
}
}

impl<Fut> Abortable<Fut>
where
Fut: Future,
{
/// Creates a new `Abortable` future using an existing `AbortRegistration`.
/// `AbortRegistration`s can be acquired through `AbortHandle::new`.
///
/// When `abort` is called on the handle tied to `reg` or if `abort` has
/// already been called, the future will complete immediately without making
/// any further progress.
///
/// Example:
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future::{Abortable, AbortHandle, Aborted};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let future = Abortable::new(async { 2 }, abort_registration);
/// abort_handle.abort();
/// assert_eq!(future.await, Err(Aborted));
/// # });
/// ```
pub fn new(future: Fut, reg: AbortRegistration) -> Self {
assert_future::<Result<Fut::Output, Aborted>, _>(Self { future, inner: reg.inner })
}
}

/// A registration handle for a `Abortable` future.
/// Values of this type can be acquired from `AbortHandle::new` and are used
/// in calls to `Abortable::new`.
#[derive(Debug)]
pub struct AbortRegistration {
inner: Arc<AbortInner>,
}

/// A handle to a `Abortable` future.
#[derive(Debug, Clone)]
pub struct AbortHandle {
inner: Arc<AbortInner>,
}

impl AbortHandle {
/// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
/// to abort a running future.
///
/// This function is usually paired with a call to `Abortable::new`.
///
/// Example:
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future::{Abortable, AbortHandle, Aborted};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let future = Abortable::new(async { 2 }, abort_registration);
/// abort_handle.abort();
/// assert_eq!(future.await, Err(Aborted));
/// # });
/// ```
pub fn new_pair() -> (Self, AbortRegistration) {
let inner =
Arc::new(AbortInner { waker: AtomicWaker::new(), cancel: AtomicBool::new(false) });

(Self { inner: inner.clone() }, AbortRegistration { inner })
}
}

// Inner type storing the waker to awaken and a bool indicating that it
// should be cancelled.
#[derive(Debug)]
struct AbortInner {
waker: AtomicWaker,
cancel: AtomicBool,
}

/// Creates a new `Abortable` future and a `AbortHandle` which can be used to stop it.
/// Creates a new `Abortable` future and an `AbortHandle` which can be used to stop it.
///
/// This function is a convenient (but less flexible) alternative to calling
/// `AbortHandle::new` and `Abortable::new` manually.
Expand All @@ -107,63 +14,6 @@ where
Fut: Future,
{
let (handle, reg) = AbortHandle::new_pair();
(Abortable::new(future, reg), handle)
}

/// Indicator that the `Abortable` future was aborted.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Aborted;

impl fmt::Display for Aborted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "`Abortable` future has been aborted")
}
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}

impl<Fut> Future for Abortable<Fut>
where
Fut: Future,
{
type Output = Result<Fut::Output, Aborted>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Check if the future has been aborted
if self.inner.cancel.load(Ordering::Relaxed) {
return Poll::Ready(Err(Aborted));
}

// attempt to complete the future
if let Poll::Ready(x) = self.as_mut().project().future.poll(cx) {
return Poll::Ready(Ok(x));
}

// Register to receive a wakeup if the future is aborted in the... future
self.inner.waker.register(cx.waker());

// Check to see if the future was aborted between the first check and
// registration.
// Checking with `Relaxed` is sufficient because `register` introduces an
// `AcqRel` barrier.
if self.inner.cancel.load(Ordering::Relaxed) {
return Poll::Ready(Err(Aborted));
}

Poll::Pending
}
}

impl AbortHandle {
/// Abort the `Abortable` future associated with this handle.
///
/// Notifies the Abortable future associated with this handle that it
/// should abort. Note that if the future is currently being polled on
/// another thread, it will not immediately stop running. Instead, it will
/// continue to run until its poll method returns.
pub fn abort(&self) {
self.inner.cancel.store(true, Ordering::Relaxed);
self.inner.waker.wake();
}
let abortable = assert_future::<Result<Fut::Output, Aborted>, _>(Abortable::new(future, reg));
(abortable, handle)
}
4 changes: 3 additions & 1 deletion futures-util/src/future/mod.rs
Expand Up @@ -110,7 +110,9 @@ cfg_target_has_atomic! {
#[cfg(feature = "alloc")]
mod abortable;
#[cfg(feature = "alloc")]
pub use self::abortable::{abortable, Abortable, AbortHandle, AbortRegistration, Aborted};
pub use crate::abortable::{Abortable, AbortHandle, AbortRegistration, Aborted};
#[cfg(feature = "alloc")]
pub use abortable::abortable;
}

// Just a helper function to ensure the futures we're returning all have the
Expand Down
5 changes: 5 additions & 0 deletions futures-util/src/lib.rs
Expand Up @@ -334,5 +334,10 @@ pub use crate::io::{
#[cfg(feature = "alloc")]
pub mod lock;

cfg_target_has_atomic! {
#[cfg(feature = "alloc")]
mod abortable;
}

mod fns;
mod unfold_state;
19 changes: 19 additions & 0 deletions futures-util/src/stream/abortable.rs
@@ -0,0 +1,19 @@
use super::assert_stream;
use crate::stream::{AbortHandle, Abortable};
use crate::Stream;

/// Creates a new `Abortable` stream and an `AbortHandle` which can be used to stop it.
///
/// This function is a convenient (but less flexible) alternative to calling
/// `AbortHandle::new` and `Abortable::new` manually.
///
/// This function is only available when the `std` or `alloc` feature of this
/// library is activated, and it is activated by default.
pub fn abortable<St>(stream: St) -> (Abortable<St>, AbortHandle)
where
St: Stream,
{
let (handle, reg) = AbortHandle::new_pair();
let abortable = assert_stream::<St::Item, _>(Abortable::new(stream, reg));
(abortable, handle)
}