Skip to content

Commit

Permalink
Make panic handling better + significantly improve performance by pol…
Browse files Browse the repository at this point in the history
…ling stream in a loop
  • Loading branch information
olegnn committed Jan 9, 2022
1 parent 9bc4e4e commit b39e690
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 90 deletions.
180 changes: 111 additions & 69 deletions futures-util/src/stream/stream/flatten_unordered.rs
Expand Up @@ -50,28 +50,47 @@ struct SharedPollState {
}

impl SharedPollState {
/// Constructs new `SharedPollState` with given state.
fn new(state: u8) -> SharedPollState {
SharedPollState { state: Arc::new(AtomicU8::new(state)) }
/// Constructs new `SharedPollState` with the given state.
fn new(value: u8) -> SharedPollState {
SharedPollState { state: Arc::new(AtomicU8::new(value)) }
}

/// Attempts to start polling, returning stored state in case of success.
/// Returns `None` if state some waker is waking at the moment.
fn start_polling(&self) -> Option<u8> {
fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>)> {
self.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
if state & WAKING_ANYTHING == NONE {
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
if value & WAKING_ANYTHING == NONE {
Some(POLLING)
} else {
None
}
})
.ok()
.map(|value| {
(
value,
PollStateBomb::new(self, move |state| {
state.stop_polling(NEED_TO_POLL_ALL);
}),
)
})
}

/// Starts the waking process and performs bitwise or with the given value.
fn start_waking(&self, to_poll: u8, waking: u8) -> u8 {
self.state.fetch_or(to_poll | waking, Ordering::SeqCst)
fn start_waking(
&self,
to_poll: u8,
waking: u8,
) -> (u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>) {
let value = self.state.fetch_or(to_poll | waking, Ordering::SeqCst);

(
value,
PollStateBomb::new(self, move |state| {
state.stop_waking(waking);
}),
)
}

/// Toggles state to non-waking, allowing to start polling.
Expand All @@ -82,8 +101,8 @@ impl SharedPollState {
/// Sets current state to `!POLLING`, allowing to use wakers.
fn stop_polling(&self, to_poll: u8) -> u8 {
self.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
Some((state | to_poll) & !POLLING)
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
Some((value | to_poll) & !POLLING)
})
.unwrap()
}
Expand All @@ -98,7 +117,6 @@ struct InnerWaker {
}

unsafe impl Send for InnerWaker {}

unsafe impl Sync for InnerWaker {}

impl InnerWaker {
Expand All @@ -115,20 +133,39 @@ impl InnerWaker {
waker(self_arc.clone())
}

// Flags state that walking is started for the walker with the given value.
fn start_waking(&self) -> u8 {
// Flags state that waking is started for the waker with the given value.
fn start_waking(&self) -> (u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>) {
self.poll_state.start_waking(self.need_to_poll, self.need_to_poll << 3)
}
}

///
struct PollStateBomb<'a, F: FnOnce(&SharedPollState)> {
state: &'a SharedPollState,
drop: Option<F>,
}

impl<'a, F: FnOnce(&SharedPollState)> PollStateBomb<'a, F> {
fn new(state: &'a SharedPollState, drop: F) -> Self {
Self { state, drop: Some(drop) }
}

fn deactivate(mut self) {
self.drop.take();
}
}

// Flags state that walking is finished for the walker with the given value.
fn stop_waking(&self) -> u8 {
self.poll_state.stop_waking(self.need_to_poll << 3)
impl<F: FnOnce(&SharedPollState)> Drop for PollStateBomb<'_, F> {
fn drop(&mut self) {
if let Some(drop) = self.drop.take() {
(drop)(&self.state);
}
}
}

impl ArcWake for InnerWaker {
fn wake_by_ref(self_arc: &Arc<Self>) {
let poll_state_value = self_arc.start_waking();
let (poll_state_value, state_bomb) = self_arc.start_waking();

// Only call waker if stream isn't being polled because of safety reasons.
// Waker will be called at the end of polling if state was changed.
Expand All @@ -137,14 +174,11 @@ impl ArcWake for InnerWaker {
unsafe { self_arc.inner_waker.get().as_ref().cloned().flatten() }
{
// First, stop waking to allow polling stream
self_arc.stop_waking();
drop(state_bomb);
// Wake inner waker
inner_waker.wake();
return;
}
}

self_arc.stop_waking();
}
}

Expand All @@ -168,30 +202,27 @@ impl<St> PollStreamFut<St> {
}
}

impl<St: Stream> Future for PollStreamFut<St> {
impl<St: Stream + Unpin> Future for PollStreamFut<St> {
type Output = Option<(St::Item, PollStreamFut<St>)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut stream = self.project().stream;

let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
ready!(stream.poll_next(cx))
} else {
None
};
let out = item.map(|item| (item, PollStreamFut::new(stream.get_mut().take())));

Poll::Ready(
item.map(|item| {
(item, PollStreamFut::new(unsafe { stream.get_unchecked_mut().take() }))
}),
)
Poll::Ready(out)
}
}

pin_project! {
/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
/// method.
#[must_use = "streams do nothing unless polled"]
#[project = FlattenUnorderedProj]
pub struct FlattenUnordered<St, U> {
#[pin]
inner_streams: FuturesUnordered<PollStreamFut<U>>,
Expand Down Expand Up @@ -224,7 +255,7 @@ where
impl<St> FlattenUnordered<St, St::Item>
where
St: Stream,
St::Item: Stream,
St::Item: Stream + Unpin,
{
pub(super) fn new(stream: St, limit: Option<usize>) -> FlattenUnordered<St, St::Item> {
let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
Expand All @@ -248,18 +279,24 @@ where
}
}

delegate_access_inner!(stream, St, ());
}

impl<St> FlattenUnorderedProj<'_, St, St::Item>
where
St: Stream,
{
/// Checks if current `inner_streams` size is less than optional limit.
fn is_exceeded_limit(&self) -> bool {
self.limit.map(|limit| self.inner_streams.len() >= limit.get()).unwrap_or(false)
self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
}

delegate_access_inner!(stream, St, ());
}

impl<St> FusedStream for FlattenUnordered<St, St::Item>
where
St: FusedStream,
St::Item: FusedStream,
St::Item: FusedStream + Unpin,
<St::Item as Stream>::Item: core::fmt::Debug,
{
fn is_terminated(&self) -> bool {
self.stream.is_terminated() && self.inner_streams.is_empty()
Expand All @@ -269,18 +306,18 @@ where
impl<St> Stream for FlattenUnordered<St, St::Item>
where
St: Stream,
St::Item: Stream,
St::Item: Stream + Unpin,
<St::Item as Stream>::Item: core::fmt::Debug,
{
type Item = <St::Item as Stream>::Item;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut next_item = None;
let mut need_to_poll_next = NONE;
let limit_exceeded = self.is_exceeded_limit();

let mut this = self.as_mut().project();

let mut poll_state_value = match this.poll_state.start_polling() {
let (mut poll_state_value, state_bomb) = match this.poll_state.start_polling() {
Some(value) => value,
_ => {
// Waker was called, just wait for the next poll
Expand All @@ -289,44 +326,48 @@ where
};

let mut polling_with_two_wakers =
!limit_exceeded && poll_state_value & NEED_TO_POLL_ALL == NEED_TO_POLL_ALL;
!this.is_exceeded_limit() && poll_state_value & NEED_TO_POLL_ALL == NEED_TO_POLL_ALL;

if poll_state_value & NEED_TO_POLL_STREAM != NONE {
if !limit_exceeded && !*this.is_stream_done {
match if polling_with_two_wakers {
// Safety: now state is `POLLING`.
let waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) };
let mut cx = Context::from_waker(&waker);
this.stream.as_mut().poll_next(&mut cx)
loop {
if this.is_exceeded_limit() || *this.is_stream_done {
polling_with_two_wakers = false;
need_to_poll_next |= NEED_TO_POLL_STREAM;

break;
} else {
this.stream.as_mut().poll_next(cx)
} {
Poll::Ready(Some(inner_stream)) => {
this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream));
need_to_poll_next |= NEED_TO_POLL_STREAM;
// Polling inner streams in current iteration with the same context
// is ok because we already received `Poll::Ready` from
// stream
poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
polling_with_two_wakers = false;
*this.is_stream_done = false;
}
Poll::Ready(None) => {
// Polling inner streams in current iteration with the same context
// is ok because we already received `Poll::Ready` from
// stream
polling_with_two_wakers = false;
*this.is_stream_done = true;
}
Poll::Pending => {
if !polling_with_two_wakers {
match if polling_with_two_wakers {
// Safety: now state is `POLLING`.
let waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) };
let mut cx = Context::from_waker(&waker);
this.stream.as_mut().poll_next(&mut cx)
} else {
this.stream.as_mut().poll_next(cx)
} {
Poll::Ready(Some(inner_stream)) => {
this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream));
need_to_poll_next |= NEED_TO_POLL_STREAM;
// Polling inner streams in current iteration with the same context
// is ok because we already received `Poll::Ready` from
// stream
poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
*this.is_stream_done = false;
}
Poll::Ready(None) => {
// Polling inner streams in current iteration with the same context
// is ok because we already received `Poll::Ready` from
// stream
*this.is_stream_done = true;
}
Poll::Pending => {
if !polling_with_two_wakers {
need_to_poll_next |= NEED_TO_POLL_STREAM;
}
*this.is_stream_done = false;
break;
}
*this.is_stream_done = false;
}
}
} else {
need_to_poll_next |= NEED_TO_POLL_STREAM;
}
}

Expand All @@ -345,7 +386,7 @@ where
need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
}
Poll::Ready(Some(None)) => {
need_to_poll_next |= NEED_TO_POLL_ALL;
need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
}
Poll::Pending => {
if !polling_with_two_wakers {
Expand All @@ -358,14 +399,15 @@ where
}
}

state_bomb.deactivate();
poll_state_value = this.poll_state.stop_polling(need_to_poll_next);
let is_done = *this.is_stream_done && this.inner_streams.is_empty();

if next_item.is_some() || is_done {
Poll::Ready(next_item)
} else {
if poll_state_value & NEED_TO_POLL_ALL != NONE
|| !self.is_exceeded_limit() && need_to_poll_next & NEED_TO_POLL_STREAM != NONE
|| !this.is_exceeded_limit() && need_to_poll_next & NEED_TO_POLL_STREAM != NONE
{
cx.waker().wake_by_ref();
}
Expand Down
13 changes: 7 additions & 6 deletions futures-util/src/stream/stream/mod.rs
Expand Up @@ -210,7 +210,7 @@ delegate_all!(
FlattenUnordered<St>(
flatten_unordered::FlattenUnordered<St, St::Item>
): Debug + Sink + Stream + FusedStream + AccessInner[St, (.)] + New[|x: St, limit: Option<usize>| flatten_unordered::FlattenUnordered::new(x, limit)]
where St: Stream, St::Item: Stream
where St: Stream, St::Item: Stream, St::Item: Unpin
);

#[cfg(not(futures_no_atomic_cas))]
Expand All @@ -220,7 +220,7 @@ delegate_all!(
FlatMapUnordered<St, U, F>(
FlattenUnordered<Map<St, F>>
): Debug + Sink + Stream + FusedStream + AccessInner[St, (. .)] + New[|x: St, limit: Option<usize>, f: F| FlattenUnordered::new(Map::new(x, f), limit)]
where St: Stream, U: Stream, F: FnMut(St::Item) -> U
where St: Stream, U: Stream, U: Unpin, F: FnMut(St::Item) -> U
);

#[cfg(not(futures_no_atomic_cas))]
Expand Down Expand Up @@ -790,11 +790,11 @@ pub trait StreamExt: Stream {
/// assert_eq!(output, vec![1, 2, 3, 4]);
/// # });
/// ```
#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
#[cfg(not(futures_no_atomic_cas))]
#[cfg(feature = "alloc")]
fn flatten_unordered(self, limit: impl Into<Option<usize>>) -> FlattenUnordered<Self>
where
Self::Item: Stream,
Self::Item: Stream + Unpin,
Self: Sized,
{
FlattenUnordered::new(self, limit.into())
Expand Down Expand Up @@ -871,15 +871,16 @@ pub trait StreamExt: Stream {
///
/// assert_eq!(vec![1usize, 2, 2, 3, 3, 3, 4, 4, 4, 4], values);
/// # });
#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
/// ```
#[cfg(not(futures_no_atomic_cas))]
#[cfg(feature = "alloc")]
fn flat_map_unordered<U, F>(
self,
limit: impl Into<Option<usize>>,
f: F,
) -> FlatMapUnordered<Self, U, F>
where
U: Stream,
U: Stream + Unpin,
F: FnMut(Self::Item) -> U,
Self: Sized,
{
Expand Down

0 comments on commit b39e690

Please sign in to comment.