Skip to content

Commit

Permalink
Synchronize FuturesUnordered length and "head_all"
Browse files Browse the repository at this point in the history
This alters the way FuturesUnordered tracks the number of futures it
contains so that the length can be immediately known any time an
immutable iterator is created, even when multiple threads share access,
without having to manually count the number of futures at runtime. Each
Task itself stores the number of futures in the set at the time of
insertion, so the length can be retrieved from whatever Task was loaded
from FuturesUnordered::head_all at the time the iterator was created.
The head node's length value is corrected when futures are removed,
ensuring the correct length will carry over to iterators created
afterwards.
  • Loading branch information
okready committed Jan 28, 2020
1 parent 447fba1 commit b4bc419
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 96 deletions.
41 changes: 5 additions & 36 deletions futures-util/src/stream/futures_unordered/iter.rs
@@ -1,10 +1,8 @@
use super::FuturesUnordered;
use super::task::Task;
use super::UNINITIALIZED_ITER_LEN;
use core::cell::Cell;
use core::marker::PhantomData;
use core::pin::Pin;
use core::sync::atomic::Ordering::{Acquire, Relaxed};
use core::sync::atomic::Ordering::Relaxed;

#[derive(Debug)]
/// Mutable iterator over all futures in the unordered set.
Expand All @@ -22,7 +20,7 @@ pub struct IterMut<'a, Fut: Unpin> (pub(super) IterPinMut<'a, Fut>);
/// Immutable iterator over all futures in the unordered set.
pub struct IterPinRef<'a, Fut> {
pub(super) task: *const Task<Fut>,
pub(super) len: Cell<usize>,
pub(super) len: usize,
pub(super) _marker: PhantomData<&'a FuturesUnordered<Fut>>
}

Expand Down Expand Up @@ -82,44 +80,15 @@ impl<'a, Fut> Iterator for IterPinRef<'a, Fut> {
}
unsafe {
let future = (*(*self.task).future.get()).as_ref().unwrap();

// `next_all` may still be `PENDING_NEXT_ALL` if its actual value
// has yet to be committed, so loop until we get a valid value.
let mut next;
while {
next = (*self.task).next_all.load(Acquire);
next == Task::PENDING_NEXT_ALL
} {}
let next = (*self.task).spin_next_all();
self.task = next;

// `len` only needs to be updated if it has been initialized.
let len = self.len.get();
if len != UNINITIALIZED_ITER_LEN {
self.len.set(len - 1);
}

self.len -= 1;
Some(Pin::new_unchecked(future))
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
// Initialize `len` if necessary.
let mut len = self.len.get();
if len == UNINITIALIZED_ITER_LEN {
len = 0;
let mut task = self.task;
while !task.is_null() {
len += 1;
while {
task = unsafe { (*task).next_all.load(Acquire) };
task == Task::PENDING_NEXT_ALL
} {}
}

self.len.set(len);
}

(len, Some(len))
(self.len, Some(self.len))
}
}

Expand Down
134 changes: 78 additions & 56 deletions futures-util/src/stream/futures_unordered/mod.rs
Expand Up @@ -8,15 +8,15 @@ use futures_core::stream::{FusedStream, Stream};
use futures_core::task::{Context, Poll};
use futures_task::{FutureObj, LocalFutureObj, Spawn, LocalSpawn, SpawnError};
use crate::task::AtomicWaker;
use core::cell::{Cell, UnsafeCell};
use core::cell::UnsafeCell;
use core::fmt::{self, Debug};
use core::iter::FromIterator;
use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::ptr;
use core::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
use core::sync::atomic::{AtomicPtr, AtomicBool, AtomicUsize};
use core::sync::atomic::{AtomicPtr, AtomicBool};
use alloc::sync::{Arc, Weak};

mod abort;
Expand All @@ -30,15 +30,6 @@ use self::task::Task;
mod ready_to_run_queue;
use self::ready_to_run_queue::{ReadyToRunQueue, Dequeue};

/// Constant used for a `FuturesUnordered` to indicate we are empty and have
/// yielded a `None` element so can return `true` from
/// `FusedStream::is_terminated`
///
/// It is safe to not check for this when incrementing as even a ZST future will
/// have a `Task` allocated for it, so we cannot ever reach usize::max_value()
/// without running out of ram.
const TERMINATED_SENTINEL_LENGTH: usize = usize::max_value();

/// Constant used for a `FuturesUnordered` to determine how many times it is
/// allowed to poll underlying futures without yielding.
///
Expand All @@ -56,22 +47,6 @@ const TERMINATED_SENTINEL_LENGTH: usize = usize::max_value();
/// See also https://github.com/rust-lang/futures-rs/issues/2047.
const YIELD_EVERY: usize = 32;

/// Indicator for uninitialized `IterPinRef::len` values.
///
/// Updates to `len` and `head_all` are not synchronized with each other when
/// pushing futures onto a `FuturesUnordered`, so we can't rely on `len`
/// providing the correct list length when `iter` grabs the `len` and `head_all`
/// fields. Existing code may rely on `IterPinRef::size_hint` returning accurate
/// values, so the size in `IterPinRef` is lazily initialized, with this value
/// used to mark that the size has not yet been initialized. Note that this
/// isn't an issue for `IterPinMut`, as one can only be created when a thread
/// already has exclusive access to the `FuturesUnordered` and any changes to
/// `len` and `head_all` have completed.
///
/// `TERMINATED_SENTINEL_LENGTH` is already reserved as a length value that can
/// never be reached for `FuturesUnordered`, so we can simple reuse it here.
const UNINITIALIZED_ITER_LEN: usize = TERMINATED_SENTINEL_LENGTH;

/// A set of futures which may complete in any order.
///
/// This structure is optimized to manage a large number of futures.
Expand All @@ -95,8 +70,8 @@ const UNINITIALIZED_ITER_LEN: usize = TERMINATED_SENTINEL_LENGTH;
#[must_use = "streams do nothing unless polled"]
pub struct FuturesUnordered<Fut> {
ready_to_run_queue: Arc<ReadyToRunQueue<Fut>>,
len: AtomicUsize,
head_all: AtomicPtr<Task<Fut>>,
is_terminated: AtomicBool,
}

unsafe impl<Fut: Send> Send for FuturesUnordered<Fut> {}
Expand Down Expand Up @@ -157,6 +132,7 @@ impl<Fut: Future> FuturesUnordered<Fut> {
future: UnsafeCell::new(None),
next_all: AtomicPtr::new(Task::PENDING_NEXT_ALL),
prev_all: UnsafeCell::new(ptr::null()),
len_all: UnsafeCell::new(0),
next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
queued: AtomicBool::new(true),
ready_to_run_queue: Weak::new(),
Expand All @@ -170,9 +146,9 @@ impl<Fut: Future> FuturesUnordered<Fut> {
});

FuturesUnordered {
len: 0.into(),
head_all: AtomicPtr::new(ptr::null_mut()),
ready_to_run_queue,
is_terminated: AtomicBool::new(false),
}
}
}
Expand All @@ -188,14 +164,15 @@ impl<Fut> FuturesUnordered<Fut> {
///
/// This represents the total number of in-flight futures.
pub fn len(&self) -> usize {
let len = self.len.load(Relaxed);
if len == TERMINATED_SENTINEL_LENGTH { 0 } else { len }
let (_, len) = self.atomic_load_head_and_len_all();
len
}

/// Returns `true` if the set contains no futures.
pub fn is_empty(&self) -> bool {
let len = self.len.load(Relaxed);
len == 0 || len == TERMINATED_SENTINEL_LENGTH
// Relaxed ordering can be used here since we don't need to read from
// the head pointer, only check whether it is null.
self.head_all.load(Relaxed).is_null()
}

/// Push a future into the set.
Expand All @@ -209,11 +186,16 @@ impl<Fut> FuturesUnordered<Fut> {
future: UnsafeCell::new(Some(future)),
next_all: AtomicPtr::new(Task::PENDING_NEXT_ALL),
prev_all: UnsafeCell::new(ptr::null_mut()),
len_all: UnsafeCell::new(0),
next_ready_to_run: AtomicPtr::new(ptr::null_mut()),
queued: AtomicBool::new(true),
ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue),
});

// Reset the `is_terminated` flag if we've previously marked ourselves
// as terminated.
self.is_terminated.store(false, Relaxed);

// Right now our task has a strong reference count of 1. We transfer
// ownership of this reference count to our internal linked list
// and we'll reclaim ownership through the `unlink` method below.
Expand All @@ -233,9 +215,11 @@ impl<Fut> FuturesUnordered<Fut> {

/// Returns an iterator that allows inspecting each future in the set.
fn iter_pin_ref(self: Pin<&Self>) -> IterPinRef<'_, Fut> {
let (task, len) = self.atomic_load_head_and_len_all();

IterPinRef {
task: self.head_all.load(Acquire),
len: Cell::new(UNINITIALIZED_ITER_LEN),
task,
len,
_marker: PhantomData,
}
}
Expand All @@ -247,13 +231,41 @@ impl<Fut> FuturesUnordered<Fut> {

/// Returns an iterator that allows modifying each future in the set.
pub fn iter_pin_mut(mut self: Pin<&mut Self>) -> IterPinMut<'_, Fut> {
// `head_all` can be accessed directly and we don't need to spin on
// `Task::next_all` since we have exclusive access to the set.
let task = *self.head_all.get_mut();
let len = if task.is_null() {
0
} else {
unsafe {
*(*task).len_all.get()
}
};

IterPinMut {
task: *self.head_all.get_mut(),
len: self.len(),
task,
len,
_marker: PhantomData
}
}

/// Returns the current head node and number of futures in the list of all
/// futures within a context where access is shared with other threads
/// (mostly for use with the `len` and `iter_pin_ref` methods).
fn atomic_load_head_and_len_all(&self) -> (*const Task<Fut>, usize) {
let task = self.head_all.load(Acquire);
let len = if task.is_null() {
0
} else {
unsafe {
(*task).spin_next_all();
*(*task).len_all.get()
}
};

(task, len)
}

/// Releases the task. It destorys the future inside and either drops
/// the `Arc<Task>` or transfers ownership to the ready to run queue.
/// The task this method is called on must have been unlinked before.
Expand Down Expand Up @@ -304,17 +316,23 @@ impl<Fut> FuturesUnordered<Fut> {
// Atomically swap out the old head node to get the node that should be
// assigned to `next_all`. We don't need acquire ordering for the load
// portion, as we are only using the `head_all` value and not accessing
// its data in this context (acquire ordering is used for loads in the
// `iter` method and during `IterPinRef` use, so read ordering is
// ensured then).
// its data in this context (acquire ordering will still be used in
// other contexts where the node's contents are read, so read ordering
// is ensured then).
let next = self.head_all.swap(ptr as *mut _, Release);

unsafe {
// We can use relaxed ordering here, as the `head_all` write can be
// safely ordered after this write (reordering would actually be
// helpful, since only the final `next_all` value would be read by
// other threads instead of the temporary `PENDING_NEXT_ALL`).
(*ptr).next_all.store(next, Relaxed);
// Store the new list length in the new node.
let new_len = if next.is_null() {
1
} else {
*(*next).len_all.get() + 1
};
*(*ptr).len_all.get() = new_len;

// Write the old head as the next node pointer, signaling to other
// threads that `len_all` and `next_all` are ready to read.
(*ptr).next_all.store(next, Release);

// `prev_all` updates don't need to be synchronized, as the field is
// only ever used after exclusive access has been acquired.
Expand All @@ -323,13 +341,6 @@ impl<Fut> FuturesUnordered<Fut> {
}
}

// `len` will wrap to zero if we've previously marked ourselves as
// terminated, so an extra increment can be done to ensure `len` updates
// to the correct length.
if self.len.fetch_add(1, Relaxed) == TERMINATED_SENTINEL_LENGTH {
self.len.fetch_add(1, Relaxed);
}

ptr
}

Expand All @@ -338,6 +349,12 @@ impl<Fut> FuturesUnordered<Fut> {
/// This method is unsafe because it has be guaranteed that `task` is a
/// valid pointer.
unsafe fn unlink(&mut self, task: *const Task<Fut>) -> Arc<Task<Fut>> {
// Compute the new list length now in case we're removing the head node
// and won't be able to retrieve the correct length later.
let head = *self.head_all.get_mut();
debug_assert!(!head.is_null());
let new_len = *(*head).len_all.get() - 1;

let task = Arc::from_raw(task);
let next = task.next_all.load(Relaxed);
let prev = *task.prev_all.get();
Expand All @@ -353,8 +370,13 @@ impl<Fut> FuturesUnordered<Fut> {
} else {
*self.head_all.get_mut() = next;
}
let old_len = *self.len.get_mut();
*self.len.get_mut() = old_len - 1;

// Store the new list length in the head node.
let head = *self.head_all.get_mut();
if !head.is_null() {
*(*head).len_all.get() = new_len;
}

task
}
}
Expand All @@ -380,7 +402,7 @@ impl<Fut: Future> Stream for FuturesUnordered<Fut> {
if self.is_empty() {
// We can only consider ourselves terminated once we
// have yielded a `None`
*self.len.get_mut() = TERMINATED_SENTINEL_LENGTH;
*self.is_terminated.get_mut() = true;
return Poll::Ready(None);
} else {
return Poll::Pending;
Expand Down Expand Up @@ -568,6 +590,6 @@ impl<Fut: Future> FromIterator<Fut> for FuturesUnordered<Fut> {

impl<Fut: Future> FusedStream for FuturesUnordered<Fut> {
fn is_terminated(&self) -> bool {
self.len.load(Relaxed) == TERMINATED_SENTINEL_LENGTH
self.is_terminated.load(Relaxed)
}
}
29 changes: 25 additions & 4 deletions futures-util/src/stream/futures_unordered/task.rs
@@ -1,6 +1,6 @@
use core::cell::UnsafeCell;
use core::sync::atomic::{AtomicPtr, AtomicBool};
use core::sync::atomic::Ordering::SeqCst;
use core::sync::atomic::Ordering::{Acquire, SeqCst};
use alloc::sync::{Arc, Weak};

use crate::task::{ArcWake, WakerRef, waker_ref};
Expand All @@ -11,14 +11,18 @@ pub(super) struct Task<Fut> {
// The future
pub(super) future: UnsafeCell<Option<Fut>>,

// Next pointer for linked list tracking all active tasks (initialized to
// `PENDING_NEXT_ALL` and atomically updated to the correct value *after*
// insertion into the list)
// Next pointer for linked list tracking all active tasks (use
// `spin_next_all` to read when access is shared across threads)
pub(super) next_all: AtomicPtr<Task<Fut>>,

// Previous task in linked list tracking all active tasks
pub(super) prev_all: UnsafeCell<*const Task<Fut>>,

// Length of the linked list tracking all active tasks when this node was
// inserted (use `spin_next_all` to synchronize before reading when access
// is shared across threads)
pub(super) len_all: UnsafeCell<usize>,

// Next pointer in ready to run queue
pub(super) next_ready_to_run: AtomicPtr<Task<Fut>>,

Expand Down Expand Up @@ -94,6 +98,23 @@ impl<Fut> Task<Fut> {
pub(super) fn waker_ref<'a>(this: &'a Arc<Task<Fut>>) -> WakerRef<'a> {
waker_ref(this)
}

/// Spins until `next_all` is no longer set to `PENDING_NEXT_ALL`.
///
/// The temporary `PENDING_NEXT_ALL` value is typically overwritten fairly
/// quickly after a node is inserted into the list of all futures, so this
/// should rarely spin much.
///
/// When it returns, the correct `next_all` value is returned, and `len_all`
/// is safe to read.
pub(super) fn spin_next_all(&self) -> *const Self {
loop {
let next = self.next_all.load(Acquire);
if next != Self::PENDING_NEXT_ALL {
return next;
}
}
}
}

impl<Fut> Drop for Task<Fut> {
Expand Down

0 comments on commit b4bc419

Please sign in to comment.