Skip to content

Commit

Permalink
Change SelectAll iterators to return stream instead of StreamFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e committed May 10, 2021
1 parent 5818cca commit ddeec92
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 33 deletions.
20 changes: 10 additions & 10 deletions futures-util/src/stream/futures_unordered/iter.rs
Expand Up @@ -4,33 +4,33 @@ use core::marker::PhantomData;
use core::pin::Pin;
use core::sync::atomic::Ordering::Relaxed;

#[derive(Debug)]
/// Mutable iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct IterPinMut<'a, Fut> {
pub(super) task: *const Task<Fut>,
pub(super) len: usize,
pub(super) _marker: PhantomData<&'a mut FuturesUnordered<Fut>>,
}

#[derive(Debug)]
/// Mutable iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct IterMut<'a, Fut: Unpin>(pub(super) IterPinMut<'a, Fut>);

#[derive(Debug)]
/// Immutable iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct IterPinRef<'a, Fut> {
pub(super) task: *const Task<Fut>,
pub(super) len: usize,
pub(super) pending_next_all: *mut Task<Fut>,
pub(super) _marker: PhantomData<&'a FuturesUnordered<Fut>>,
}

#[derive(Debug)]
/// Immutable iterator over all the futures in the unordered set.
#[derive(Debug)]
pub struct Iter<'a, Fut: Unpin>(pub(super) IterPinRef<'a, Fut>);

#[derive(Debug)]
/// Owned iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct IntoIter<Fut: Unpin> {
pub(super) len: usize,
pub(super) inner: FuturesUnordered<Fut>,
Expand All @@ -39,7 +39,7 @@ pub struct IntoIter<Fut: Unpin> {
impl<Fut: Unpin> Iterator for IntoIter<Fut> {
type Item = Fut;

fn next(&mut self) -> Option<Fut> {
fn next(&mut self) -> Option<Self::Item> {
// `head_all` can be accessed directly and we don't need to spin on
// `Task::next_all` since we have exclusive access to the set.
let task = self.inner.head_all.get_mut();
Expand Down Expand Up @@ -73,7 +73,7 @@ impl<Fut: Unpin> ExactSizeIterator for IntoIter<Fut> {}
impl<'a, Fut> Iterator for IterPinMut<'a, Fut> {
type Item = Pin<&'a mut Fut>;

fn next(&mut self) -> Option<Pin<&'a mut Fut>> {
fn next(&mut self) -> Option<Self::Item> {
if self.task.is_null() {
return None;
}
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<Fut> ExactSizeIterator for IterPinMut<'_, Fut> {}
impl<'a, Fut: Unpin> Iterator for IterMut<'a, Fut> {
type Item = &'a mut Fut;

fn next(&mut self) -> Option<&'a mut Fut> {
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(Pin::get_mut)
}

Expand All @@ -116,7 +116,7 @@ impl<Fut: Unpin> ExactSizeIterator for IterMut<'_, Fut> {}
impl<'a, Fut> Iterator for IterPinRef<'a, Fut> {
type Item = Pin<&'a Fut>;

fn next(&mut self) -> Option<Pin<&'a Fut>> {
fn next(&mut self) -> Option<Self::Item> {
if self.task.is_null() {
return None;
}
Expand Down Expand Up @@ -145,7 +145,7 @@ impl<Fut> ExactSizeIterator for IterPinRef<'_, Fut> {}
impl<'a, Fut: Unpin> Iterator for Iter<'a, Fut> {
type Item = &'a Fut;

fn next(&mut self) -> Option<&'a Fut> {
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(Pin::get_ref)
}

Expand Down
2 changes: 1 addition & 1 deletion futures-util/src/stream/mod.rs
Expand Up @@ -105,7 +105,7 @@ cfg_target_has_atomic! {
pub use self::futures_unordered::FuturesUnordered;

#[cfg(feature = "alloc")]
mod select_all;
pub mod select_all;
#[cfg(feature = "alloc")]
pub use self::select_all::{select_all, SelectAll};

Expand Down
100 changes: 78 additions & 22 deletions futures-util/src/stream/select_all.rs
Expand Up @@ -11,7 +11,7 @@ use futures_core::task::{Context, Poll};
use pin_project_lite::pin_project;

use super::assert_stream;
use crate::stream::futures_unordered::{IntoIter, Iter, IterMut, IterPinMut, IterPinRef};
use crate::stream::futures_unordered;
use crate::stream::{FuturesUnordered, StreamExt, StreamFuture};

pin_project! {
Expand Down Expand Up @@ -72,23 +72,13 @@ impl<St: Stream + Unpin> SelectAll<St> {
}

/// Returns an iterator that allows inspecting each future in the set.
pub fn iter(&self) -> Iter<'_, StreamFuture<St>> {
self.inner.iter()
}

/// Returns an iterator that allows inspecting each future in the set.
pub fn iter_pin_ref(self: Pin<&'_ Self>) -> IterPinRef<'_, StreamFuture<St>> {
self.project_ref().inner.iter_pin_ref()
}

/// Returns an iterator that allows modifying each future in the set.
pub fn iter_mut(&mut self) -> IterMut<'_, StreamFuture<St>> {
self.inner.iter_mut()
pub fn iter(&self) -> Iter<'_, St> {
Iter(self.inner.iter())
}

/// Returns an iterator that allows modifying each future in the set.
pub fn iter_pin_mut(self: Pin<&mut Self>) -> IterPinMut<'_, StreamFuture<St>> {
self.project().inner.iter_pin_mut()
pub fn iter_mut(&mut self) -> IterMut<'_, St> {
IterMut(self.inner.iter_mut())
}

/// Clears the set, removing all futures.
Expand Down Expand Up @@ -172,28 +162,94 @@ impl<St: Stream + Unpin> Extend<St> for SelectAll<St> {
}

impl<St: Stream + Unpin> IntoIterator for SelectAll<St> {
type Item = StreamFuture<St>;
type IntoIter = IntoIter<StreamFuture<St>>;
type Item = St;
type IntoIter = IntoIter<St>;

fn into_iter(self) -> Self::IntoIter {
self.inner.into_iter()
IntoIter(self.inner.into_iter())
}
}

impl<'a, St: Stream + Unpin> IntoIterator for &'a SelectAll<St> {
type Item = &'a StreamFuture<St>;
type IntoIter = Iter<'a, StreamFuture<St>>;
type Item = &'a St;
type IntoIter = Iter<'a, St>;

fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

impl<'a, St: Stream + Unpin> IntoIterator for &'a mut SelectAll<St> {
type Item = &'a mut StreamFuture<St>;
type IntoIter = IterMut<'a, StreamFuture<St>>;
type Item = &'a mut St;
type IntoIter = IterMut<'a, St>;

fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}

/// Immutable iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct Iter<'a, St: Unpin>(futures_unordered::Iter<'a, StreamFuture<St>>);

/// Mutable iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct IterMut<'a, St: Unpin>(futures_unordered::IterMut<'a, StreamFuture<St>>);

/// Owned iterator over all futures in the unordered set.
#[derive(Debug)]
pub struct IntoIter<St: Unpin>(futures_unordered::IntoIter<StreamFuture<St>>);

impl<'a, St: Stream + Unpin> Iterator for Iter<'a, St> {
type Item = &'a St;

fn next(&mut self) -> Option<Self::Item> {
let st = self.0.next()?;
let next = st.get_ref();
// This should always be true because FuturesUnordered removes completed futures.
debug_assert!(next.is_some());
next
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}

impl<St: Stream + Unpin> ExactSizeIterator for Iter<'_, St> {}

impl<'a, St: Stream + Unpin> Iterator for IterMut<'a, St> {
type Item = &'a mut St;

fn next(&mut self) -> Option<Self::Item> {
let st = self.0.next()?;
let next = st.get_mut();
// This should always be true because FuturesUnordered removes completed futures.
debug_assert!(next.is_some());
next
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}

impl<St: Stream + Unpin> ExactSizeIterator for IterMut<'_, St> {}

impl<St: Stream + Unpin> Iterator for IntoIter<St> {
type Item = St;

fn next(&mut self) -> Option<Self::Item> {
let st = self.0.next()?;
let next = st.into_inner();
// This should always be true because FuturesUnordered removes completed futures.
debug_assert!(next.is_some());
next
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}

impl<St: Stream + Unpin> ExactSizeIterator for IntoIter<St> {}
96 changes: 96 additions & 0 deletions futures/tests/stream_select_all.rs
Expand Up @@ -99,3 +99,99 @@ fn clear() {
tasks.clear();
assert!(!tasks.is_terminated());
}

#[test]
fn iter_mut() {
let mut stream =
vec![stream::pending::<()>(), stream::pending::<()>(), stream::pending::<()>()]
.into_iter()
.collect::<SelectAll<_>>();

let mut iter = stream.iter_mut();
assert_eq!(iter.len(), 3);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 2);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 1);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());

let mut stream = vec![stream::iter(vec![]), stream::iter(vec![1]), stream::iter(vec![2])]
.into_iter()
.collect::<SelectAll<_>>();

assert_eq!(stream.len(), 3);
assert_eq!(block_on(stream.next()), Some(1));
assert_eq!(stream.len(), 2);
let mut iter = stream.iter_mut();
assert_eq!(iter.len(), 2);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 1);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());

assert_eq!(block_on(stream.next()), Some(2));
assert_eq!(stream.len(), 2);
assert_eq!(block_on(stream.next()), None);
let mut iter = stream.iter_mut();
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());
}

#[test]
fn iter() {
let stream = vec![stream::pending::<()>(), stream::pending::<()>(), stream::pending::<()>()]
.into_iter()
.collect::<SelectAll<_>>();

let mut iter = stream.iter();
assert_eq!(iter.len(), 3);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 2);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 1);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());

let mut stream = vec![stream::iter(vec![]), stream::iter(vec![1]), stream::iter(vec![2])]
.into_iter()
.collect::<SelectAll<_>>();

assert_eq!(stream.len(), 3);
assert_eq!(block_on(stream.next()), Some(1));
assert_eq!(stream.len(), 2);
let mut iter = stream.iter();
assert_eq!(iter.len(), 2);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 1);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());

assert_eq!(block_on(stream.next()), Some(2));
assert_eq!(stream.len(), 2);
assert_eq!(block_on(stream.next()), None);
let mut iter = stream.iter();
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());
}

#[test]
fn into_iter() {
let stream = vec![stream::pending::<()>(), stream::pending::<()>(), stream::pending::<()>()]
.into_iter()
.collect::<SelectAll<_>>();

let mut iter = stream.into_iter();
assert_eq!(iter.len(), 3);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 2);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 1);
assert!(iter.next().is_some());
assert_eq!(iter.len(), 0);
assert!(iter.next().is_none());
}

0 comments on commit ddeec92

Please sign in to comment.