Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to FuturesOrdered dynamically in try_join_all #2556

Merged
merged 4 commits into from Jun 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions futures-util/src/future/join_all.rs
Expand Up @@ -15,7 +15,7 @@ use super::{assert_future, MaybeDone};
#[cfg(not(futures_no_atomic_cas))]
use crate::stream::{Collect, FuturesOrdered, StreamExt};

fn iter_pin_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
pub(crate) fn iter_pin_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
// Safety: `std` _could_ make this unsound if it were to decide Pin's
// invariants aren't required to transmit through slices. Otherwise this has
// the same safety as a normal field pin projection.
Expand All @@ -32,9 +32,9 @@ where
}

#[cfg(not(futures_no_atomic_cas))]
const SMALL: usize = 30;
pub(crate) const SMALL: usize = 30;

pub(crate) enum JoinAllKind<F>
enum JoinAllKind<F>
where
F: Future,
{
Expand Down
125 changes: 87 additions & 38 deletions futures-util/src/future/try_join_all.rs
Expand Up @@ -10,14 +10,10 @@ use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};

use super::{assert_future, TryFuture, TryMaybeDone};
use super::{assert_future, join_all, TryFuture, TryMaybeDone};

fn iter_pin_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
// Safety: `std` _could_ make this unsound if it were to decide Pin's
// invariants aren't required to transmit through slices. Otherwise this has
// the same safety as a normal field pin projection.
unsafe { slice.get_unchecked_mut() }.iter_mut().map(|t| unsafe { Pin::new_unchecked(t) })
}
#[cfg(not(futures_no_atomic_cas))]
use crate::stream::{FuturesOrdered, TryCollect, TryStreamExt};

enum FinalState<E = ()> {
Pending,
Expand All @@ -31,17 +27,37 @@ pub struct TryJoinAll<F>
where
F: TryFuture,
{
elems: Pin<Box<[TryMaybeDone<F>]>>,
kind: TryJoinAllKind<F>,
}

enum TryJoinAllKind<F>
where
F: TryFuture,
{
Small {
elems: Pin<Box<[TryMaybeDone<F>]>>,
},
#[cfg(not(futures_no_atomic_cas))]
Big {
fut: TryCollect<FuturesOrdered<F>, Vec<F::Ok>>,
},
}

impl<F> fmt::Debug for TryJoinAll<F>
where
F: TryFuture + fmt::Debug,
F::Ok: fmt::Debug,
F::Error: fmt::Debug,
F::Output: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TryJoinAll").field("elems", &self.elems).finish()
match self.kind {
TryJoinAllKind::Small { ref elems } => {
f.debug_struct("TryJoinAll").field("elems", elems).finish()
}
#[cfg(not(futures_no_atomic_cas))]
TryJoinAllKind::Big { ref fut, .. } => fmt::Debug::fmt(fut, f),
}
}
}

Expand Down Expand Up @@ -83,54 +99,87 @@ where
/// assert_eq!(try_join_all(futures).await, Err(2));
/// # });
/// ```
pub fn try_join_all<I>(i: I) -> TryJoinAll<I::Item>
pub fn try_join_all<I>(iter: I) -> TryJoinAll<I::Item>
where
I: IntoIterator,
I::Item: TryFuture,
I::Item: TryFuture
+ Future<Output = Result<<I::Item as TryFuture>::Ok, <I::Item as TryFuture>::Error>>,
ibraheemdev marked this conversation as resolved.
Show resolved Hide resolved
{
let elems: Box<[_]> = i.into_iter().map(TryMaybeDone::Future).collect();
assert_future::<Result<Vec<<I::Item as TryFuture>::Ok>, <I::Item as TryFuture>::Error>, _>(
TryJoinAll { elems: elems.into() },
)
#[cfg(futures_no_atomic_cas)]
{
let elems = iter.into_iter().map(TryMaybeDone::Future).try_collect::<Box<[_]>>().into();
let kind = TryJoinAllKind::Small { elems };
assert_future::<Result<Vec<<I::Item as TryFuture>::Ok>, <I::Item as TryFuture>::Error>, _>(
TryJoinAll { kind },
)
}
#[cfg(not(futures_no_atomic_cas))]
{
let iter = iter.into_iter();
let kind = match iter.size_hint().1 {
None => TryJoinAllKind::Big { fut: iter.collect::<FuturesOrdered<_>>().try_collect() },
Some(max) => {
if max <= join_all::SMALL {
let elems = iter.map(TryMaybeDone::Future).collect::<Box<[_]>>().into();
TryJoinAllKind::Small { elems }
} else {
TryJoinAllKind::Big { fut: iter.collect::<FuturesOrdered<_>>().try_collect() }
}
}
};
assert_future::<Result<Vec<<I::Item as TryFuture>::Ok>, <I::Item as TryFuture>::Error>, _>(
TryJoinAll { kind },
)
}
}

impl<F> Future for TryJoinAll<F>
where
F: TryFuture,
F: TryFuture + Future<Output = Result<F::Ok, F::Error>>,
{
type Output = Result<Vec<F::Ok>, F::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = FinalState::AllDone;
match &mut self.kind {
TryJoinAllKind::Small { elems } => {
let mut state = FinalState::AllDone;

for elem in iter_pin_mut(self.elems.as_mut()) {
match elem.try_poll(cx) {
Poll::Pending => state = FinalState::Pending,
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => {
state = FinalState::Error(e);
break;
for elem in join_all::iter_pin_mut(elems.as_mut()) {
match elem.try_poll(cx) {
Poll::Pending => state = FinalState::Pending,
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => {
state = FinalState::Error(e);
break;
}
}
}
}
}

match state {
FinalState::Pending => Poll::Pending,
FinalState::AllDone => {
let mut elems = mem::replace(&mut self.elems, Box::pin([]));
let results =
iter_pin_mut(elems.as_mut()).map(|e| e.take_output().unwrap()).collect();
Poll::Ready(Ok(results))
}
FinalState::Error(e) => {
let _ = mem::replace(&mut self.elems, Box::pin([]));
Poll::Ready(Err(e))
match state {
FinalState::Pending => Poll::Pending,
FinalState::AllDone => {
let mut elems = mem::replace(elems, Box::pin([]));
let results = join_all::iter_pin_mut(elems.as_mut())
.map(|e| e.take_output().unwrap())
.collect();
Poll::Ready(Ok(results))
}
FinalState::Error(e) => {
let _ = mem::replace(elems, Box::pin([]));
Poll::Ready(Err(e))
}
}
}
#[cfg(not(futures_no_atomic_cas))]
TryJoinAllKind::Big { fut } => Pin::new(fut).poll(cx),
}
}
}

impl<F: TryFuture> FromIterator<F> for TryJoinAll<F> {
impl<F> FromIterator<F> for TryJoinAll<F>
where
F: TryFuture + Future<Output = Result<F::Ok, F::Error>>,
{
fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
try_join_all(iter)
}
Expand Down