diff --git a/futures-util/src/stream/futures_unordered/iter.rs b/futures-util/src/stream/futures_unordered/iter.rs index 6a1efd7e3c..d6f340109f 100644 --- a/futures-util/src/stream/futures_unordered/iter.rs +++ b/futures-util/src/stream/futures_unordered/iter.rs @@ -1,10 +1,8 @@ use super::FuturesUnordered; use super::task::Task; -use super::UNINITIALIZED_ITER_LEN; -use core::cell::Cell; use core::marker::PhantomData; use core::pin::Pin; -use core::sync::atomic::Ordering::{Acquire, Relaxed}; +use core::sync::atomic::Ordering::Relaxed; #[derive(Debug)] /// Mutable iterator over all futures in the unordered set. @@ -22,7 +20,7 @@ pub struct IterMut<'a, Fut: Unpin> (pub(super) IterPinMut<'a, Fut>); /// Immutable iterator over all futures in the unordered set. pub struct IterPinRef<'a, Fut> { pub(super) task: *const Task, - pub(super) len: Cell, + pub(super) len: usize, pub(super) pending_next_all: *mut Task, pub(super) _marker: PhantomData<&'a FuturesUnordered> } @@ -82,44 +80,15 @@ impl<'a, Fut> Iterator for IterPinRef<'a, Fut> { } unsafe { let future = (*(*self.task).future.get()).as_ref().unwrap(); - - // `next_all` may still be `pending_next_all` if its actual value - // has yet to be committed, so loop until we get a valid value. - let mut next; - while { - next = (*self.task).next_all.load(Acquire); - next == self.pending_next_all - } {} + let next = (*self.task).spin_next_all(self.pending_next_all); self.task = next; - - // `len` only needs to be updated if it has been initialized. - let len = self.len.get(); - if len != UNINITIALIZED_ITER_LEN { - self.len.set(len - 1); - } - + self.len -= 1; Some(Pin::new_unchecked(future)) } } fn size_hint(&self) -> (usize, Option) { - // Initialize `len` if necessary. - let mut len = self.len.get(); - if len == UNINITIALIZED_ITER_LEN { - len = 0; - let mut task = self.task; - while !task.is_null() { - len += 1; - while { - task = unsafe { (*task).next_all.load(Acquire) }; - task == self.pending_next_all - } {} - } - - self.len.set(len); - } - - (len, Some(len)) + (self.len, Some(self.len)) } } diff --git a/futures-util/src/stream/futures_unordered/mod.rs b/futures-util/src/stream/futures_unordered/mod.rs index 2a5114bf6c..f2b3bee705 100644 --- a/futures-util/src/stream/futures_unordered/mod.rs +++ b/futures-util/src/stream/futures_unordered/mod.rs @@ -8,7 +8,7 @@ use futures_core::stream::{FusedStream, Stream}; use futures_core::task::{Context, Poll}; use futures_task::{FutureObj, LocalFutureObj, Spawn, LocalSpawn, SpawnError}; use crate::task::AtomicWaker; -use core::cell::{Cell, UnsafeCell}; +use core::cell::UnsafeCell; use core::fmt::{self, Debug}; use core::iter::FromIterator; use core::marker::PhantomData; @@ -16,7 +16,7 @@ use core::mem; use core::pin::Pin; use core::ptr; use core::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst}; -use core::sync::atomic::{AtomicPtr, AtomicBool, AtomicUsize}; +use core::sync::atomic::{AtomicPtr, AtomicBool}; use alloc::sync::{Arc, Weak}; mod abort; @@ -30,15 +30,6 @@ use self::task::Task; mod ready_to_run_queue; use self::ready_to_run_queue::{ReadyToRunQueue, Dequeue}; -/// Constant used for a `FuturesUnordered` to indicate we are empty and have -/// yielded a `None` element so can return `true` from -/// `FusedStream::is_terminated` -/// -/// It is safe to not check for this when incrementing as even a ZST future will -/// have a `Task` allocated for it, so we cannot ever reach usize::max_value() -/// without running out of ram. -const TERMINATED_SENTINEL_LENGTH: usize = usize::max_value(); - /// Constant used for a `FuturesUnordered` to determine how many times it is /// allowed to poll underlying futures without yielding. /// @@ -56,22 +47,6 @@ const TERMINATED_SENTINEL_LENGTH: usize = usize::max_value(); /// See also https://github.com/rust-lang/futures-rs/issues/2047. const YIELD_EVERY: usize = 32; -/// Indicator for uninitialized `IterPinRef::len` values. -/// -/// Updates to `len` and `head_all` are not synchronized with each other when -/// pushing futures onto a `FuturesUnordered`, so we can't rely on `len` -/// providing the correct list length when `iter` grabs the `len` and `head_all` -/// fields. Existing code may rely on `IterPinRef::size_hint` returning accurate -/// values, so the size in `IterPinRef` is lazily initialized, with this value -/// used to mark that the size has not yet been initialized. Note that this -/// isn't an issue for `IterPinMut`, as one can only be created when a thread -/// already has exclusive access to the `FuturesUnordered` and any changes to -/// `len` and `head_all` have completed. -/// -/// `TERMINATED_SENTINEL_LENGTH` is already reserved as a length value that can -/// never be reached for `FuturesUnordered`, so we can simple reuse it here. -const UNINITIALIZED_ITER_LEN: usize = TERMINATED_SENTINEL_LENGTH; - /// A set of futures which may complete in any order. /// /// This structure is optimized to manage a large number of futures. @@ -95,8 +70,8 @@ const UNINITIALIZED_ITER_LEN: usize = TERMINATED_SENTINEL_LENGTH; #[must_use = "streams do nothing unless polled"] pub struct FuturesUnordered { ready_to_run_queue: Arc>, - len: AtomicUsize, head_all: AtomicPtr>, + is_terminated: AtomicBool, } unsafe impl Send for FuturesUnordered {} @@ -157,6 +132,7 @@ impl FuturesUnordered { future: UnsafeCell::new(None), next_all: AtomicPtr::new(ptr::null_mut()), prev_all: UnsafeCell::new(ptr::null()), + len_all: UnsafeCell::new(0), next_ready_to_run: AtomicPtr::new(ptr::null_mut()), queued: AtomicBool::new(true), ready_to_run_queue: Weak::new(), @@ -170,9 +146,9 @@ impl FuturesUnordered { }); FuturesUnordered { - len: 0.into(), head_all: AtomicPtr::new(ptr::null_mut()), ready_to_run_queue, + is_terminated: AtomicBool::new(false), } } } @@ -188,14 +164,15 @@ impl FuturesUnordered { /// /// This represents the total number of in-flight futures. pub fn len(&self) -> usize { - let len = self.len.load(Relaxed); - if len == TERMINATED_SENTINEL_LENGTH { 0 } else { len } + let (_, len) = self.atomic_load_head_and_len_all(); + len } /// Returns `true` if the set contains no futures. pub fn is_empty(&self) -> bool { - let len = self.len.load(Relaxed); - len == 0 || len == TERMINATED_SENTINEL_LENGTH + // Relaxed ordering can be used here since we don't need to read from + // the head pointer, only check whether it is null. + self.head_all.load(Relaxed).is_null() } /// Push a future into the set. @@ -209,11 +186,16 @@ impl FuturesUnordered { future: UnsafeCell::new(Some(future)), next_all: AtomicPtr::new(self.pending_next_all()), prev_all: UnsafeCell::new(ptr::null_mut()), + len_all: UnsafeCell::new(0), next_ready_to_run: AtomicPtr::new(ptr::null_mut()), queued: AtomicBool::new(true), ready_to_run_queue: Arc::downgrade(&self.ready_to_run_queue), }); + // Reset the `is_terminated` flag if we've previously marked ourselves + // as terminated. + self.is_terminated.store(false, Relaxed); + // Right now our task has a strong reference count of 1. We transfer // ownership of this reference count to our internal linked list // and we'll reclaim ownership through the `unlink` method below. @@ -233,9 +215,11 @@ impl FuturesUnordered { /// Returns an iterator that allows inspecting each future in the set. fn iter_pin_ref(self: Pin<&Self>) -> IterPinRef<'_, Fut> { + let (task, len) = self.atomic_load_head_and_len_all(); + IterPinRef { - task: self.head_all.load(Acquire), - len: Cell::new(UNINITIALIZED_ITER_LEN), + task, + len, pending_next_all: self.pending_next_all(), _marker: PhantomData, } @@ -248,13 +232,41 @@ impl FuturesUnordered { /// Returns an iterator that allows modifying each future in the set. pub fn iter_pin_mut(mut self: Pin<&mut Self>) -> IterPinMut<'_, Fut> { + // `head_all` can be accessed directly and we don't need to spin on + // `Task::next_all` since we have exclusive access to the set. + let task = *self.head_all.get_mut(); + let len = if task.is_null() { + 0 + } else { + unsafe { + *(*task).len_all.get() + } + }; + IterPinMut { - task: *self.head_all.get_mut(), - len: self.len(), + task, + len, _marker: PhantomData } } + /// Returns the current head node and number of futures in the list of all + /// futures within a context where access is shared with other threads + /// (mostly for use with the `len` and `iter_pin_ref` methods). + fn atomic_load_head_and_len_all(&self) -> (*const Task, usize) { + let task = self.head_all.load(Acquire); + let len = if task.is_null() { + 0 + } else { + unsafe { + (*task).spin_next_all(self.pending_next_all()); + *(*task).len_all.get() + } + }; + + (task, len) + } + /// Releases the task. It destorys the future inside and either drops /// the `Arc` or transfers ownership to the ready to run queue. /// The task this method is called on must have been unlinked before. @@ -305,17 +317,23 @@ impl FuturesUnordered { // Atomically swap out the old head node to get the node that should be // assigned to `next_all`. We don't need acquire ordering for the load // portion, as we are only using the `head_all` value and not accessing - // its data in this context (acquire ordering is used for loads in the - // `iter` method and during `IterPinRef` use, so read ordering is - // ensured then). + // its data in this context (acquire ordering will still be used in + // other contexts where the node's contents are read, so read ordering + // is ensured then). let next = self.head_all.swap(ptr as *mut _, Release); unsafe { - // We can use relaxed ordering here, as the `head_all` write can be - // safely ordered after this write (reordering would actually be - // helpful, since only the final `next_all` value would be read by - // other threads instead of the pending state value). - (*ptr).next_all.store(next, Relaxed); + // Store the new list length in the new node. + let new_len = if next.is_null() { + 1 + } else { + *(*next).len_all.get() + 1 + }; + *(*ptr).len_all.get() = new_len; + + // Write the old head as the next node pointer, signaling to other + // threads that `len_all` and `next_all` are ready to read. + (*ptr).next_all.store(next, Release); // `prev_all` updates don't need to be synchronized, as the field is // only ever used after exclusive access has been acquired. @@ -324,13 +342,6 @@ impl FuturesUnordered { } } - // `len` will wrap to zero if we've previously marked ourselves as - // terminated, so an extra increment can be done to ensure `len` updates - // to the correct length. - if self.len.fetch_add(1, Relaxed) == TERMINATED_SENTINEL_LENGTH { - self.len.fetch_add(1, Relaxed); - } - ptr } @@ -339,6 +350,12 @@ impl FuturesUnordered { /// This method is unsafe because it has be guaranteed that `task` is a /// valid pointer. unsafe fn unlink(&mut self, task: *const Task) -> Arc> { + // Compute the new list length now in case we're removing the head node + // and won't be able to retrieve the correct length later. + let head = *self.head_all.get_mut(); + debug_assert!(!head.is_null()); + let new_len = *(*head).len_all.get() - 1; + let task = Arc::from_raw(task); let next = task.next_all.load(Relaxed); let prev = *task.prev_all.get(); @@ -354,8 +371,13 @@ impl FuturesUnordered { } else { *self.head_all.get_mut() = next; } - let old_len = *self.len.get_mut(); - *self.len.get_mut() = old_len - 1; + + // Store the new list length in the head node. + let head = *self.head_all.get_mut(); + if !head.is_null() { + *(*head).len_all.get() = new_len; + } + task } @@ -410,7 +432,7 @@ impl Stream for FuturesUnordered { if self.is_empty() { // We can only consider ourselves terminated once we // have yielded a `None` - *self.len.get_mut() = TERMINATED_SENTINEL_LENGTH; + *self.is_terminated.get_mut() = true; return Poll::Ready(None); } else { return Poll::Pending; @@ -598,6 +620,6 @@ impl FromIterator for FuturesUnordered { impl FusedStream for FuturesUnordered { fn is_terminated(&self) -> bool { - self.len.load(Relaxed) == TERMINATED_SENTINEL_LENGTH + self.is_terminated.load(Relaxed) } } diff --git a/futures-util/src/stream/futures_unordered/task.rs b/futures-util/src/stream/futures_unordered/task.rs index 03ea7a6478..a78a6589e1 100644 --- a/futures-util/src/stream/futures_unordered/task.rs +++ b/futures-util/src/stream/futures_unordered/task.rs @@ -1,6 +1,6 @@ use core::cell::UnsafeCell; use core::sync::atomic::{AtomicPtr, AtomicBool}; -use core::sync::atomic::Ordering::SeqCst; +use core::sync::atomic::Ordering::{Acquire, SeqCst}; use alloc::sync::{Arc, Weak}; use crate::task::{ArcWake, WakerRef, waker_ref}; @@ -11,14 +11,18 @@ pub(super) struct Task { // The future pub(super) future: UnsafeCell>, - // Next pointer for linked list tracking all active tasks (initialized to a - // reserved "pending" state value and atomically updated to the correct - // value *after* insertion into the list) + // Next pointer for linked list tracking all active tasks (use + // `spin_next_all` to read when access is shared across threads) pub(super) next_all: AtomicPtr>, // Previous task in linked list tracking all active tasks pub(super) prev_all: UnsafeCell<*const Task>, + // Length of the linked list tracking all active tasks when this node was + // inserted (use `spin_next_all` to synchronize before reading when access + // is shared across threads) + pub(super) len_all: UnsafeCell, + // Next pointer in ready to run queue pub(super) next_ready_to_run: AtomicPtr>, @@ -69,6 +73,26 @@ impl Task { pub(super) fn waker_ref<'a>(this: &'a Arc>) -> WakerRef<'a> { waker_ref(this) } + + /// Spins until `next_all` is no longer set to `pending_next_all`. + /// + /// The temporary `pending_next_all` value is typically overwritten fairly + /// quickly after a node is inserted into the list of all futures, so this + /// should rarely spin much. + /// + /// When it returns, the correct `next_all` value is returned, and `len_all` + /// is safe to read. + pub(super) fn spin_next_all( + &self, + pending_next_all: *mut Self, + ) -> *const Self { + loop { + let next = self.next_all.load(Acquire); + if next != pending_next_all { + return next; + } + } + } } impl Drop for Task { diff --git a/futures/tests/futures_unordered.rs b/futures/tests/futures_unordered.rs index 1995a2bc25..57eb98fd1b 100644 --- a/futures/tests/futures_unordered.rs +++ b/futures/tests/futures_unordered.rs @@ -245,3 +245,43 @@ fn futures_not_moved_after_poll() { assert_stream_next!(stream, ()); assert_stream_done!(stream); } + +#[test] +fn len_valid_during_out_of_order_completion() { + // Complete futures out-of-order and add new futures afterwards to ensure + // length values remain correct. + let (a_tx, a_rx) = oneshot::channel::(); + let (b_tx, b_rx) = oneshot::channel::(); + let (c_tx, c_rx) = oneshot::channel::(); + let (d_tx, d_rx) = oneshot::channel::(); + + let mut cx = noop_context(); + let mut stream = FuturesUnordered::new(); + assert_eq!(stream.len(), 0); + + stream.push(a_rx); + assert_eq!(stream.len(), 1); + stream.push(b_rx); + assert_eq!(stream.len(), 2); + stream.push(c_rx); + assert_eq!(stream.len(), 3); + + b_tx.send(4).unwrap(); + assert_eq!(stream.poll_next_unpin(&mut cx), Poll::Ready(Some(Ok(4)))); + assert_eq!(stream.len(), 2); + + stream.push(d_rx); + assert_eq!(stream.len(), 3); + + c_tx.send(5).unwrap(); + assert_eq!(stream.poll_next_unpin(&mut cx), Poll::Ready(Some(Ok(5)))); + assert_eq!(stream.len(), 2); + + d_tx.send(6).unwrap(); + assert_eq!(stream.poll_next_unpin(&mut cx), Poll::Ready(Some(Ok(6)))); + assert_eq!(stream.len(), 1); + + a_tx.send(7).unwrap(); + assert_eq!(stream.poll_next_unpin(&mut cx), Poll::Ready(Some(Ok(7)))); + assert_eq!(stream.len(), 0); +}