Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

macros: don't take ownership of futures in macros #5087

Merged
merged 9 commits into from Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 16 additions & 5 deletions tokio/src/future/poll_fn.rs
Expand Up @@ -4,22 +4,25 @@

use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Future for the [`poll_fn`] function.
pub struct PollFn<F> {
f: F,
_pinned: PhantomPinned,
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}

impl<F> Unpin for PollFn<F> {}

/// Creates a new future wrapping around a function returning [`Poll`].
pub fn poll_fn<T, F>(f: F) -> PollFn<F>
where
F: FnMut(&mut Context<'_>) -> Poll<T>,
{
PollFn { f }
PollFn {
f,
_pinned: PhantomPinned,
}
}

impl<F> fmt::Debug for PollFn<F> {
Expand All @@ -34,7 +37,15 @@ where
{
type Output = T;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
(self.f)(cx)
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
// SAFETY: We never construct a `Pin<&mut F>` anywhere, so accessing `f`
// mutably in an unpinned way is sound.
//
// This is, strictly speaking, not necessary. We could make `PollFn`
// unconditionally `Unpin` and avoid this unsafe. However, making this
// struct `!Unpin` mitigates the issues described here:
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
let me = unsafe { Pin::into_inner_unchecked(self) };
(me.f)(cx)
}
}
8 changes: 7 additions & 1 deletion tokio/src/macros/join.rs
Expand Up @@ -74,6 +74,12 @@ macro_rules! join {
// the requirement of `Pin::new_unchecked` called below.
let mut futures = ( $( maybe_done($e), )* );

// This assignment makes sure that the `poll_fn` closure only has a
// reference to the futures, instead of taking ownership of them. This
// mitigates the issue described in
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
let mut futures = &mut futures;

// Each time the future created by poll_fn is polled, a different future will be polled first
// to ensure every future passed to join! gets a chance to make progress even if
// one of the futures consumes the whole budget.
Expand Down Expand Up @@ -106,7 +112,7 @@ macro_rules! join {
to_run -= 1;

// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;
let ( $($skip,)* fut, .. ) = &mut *futures;

// Safety: future is stored on the stack above
// and never moved.
Expand Down
8 changes: 7 additions & 1 deletion tokio/src/macros/select.rs
Expand Up @@ -462,6 +462,12 @@ macro_rules! select {
// satisfy the requirement of `Pin::new_unchecked` called below.
let mut futures = ( $( $fut , )+ );

// This assignment makes sure that the `poll_fn` closure only has a
// reference to the futures, instead of taking ownership of them.
// This mitigates the issue described in
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
let mut futures = &mut futures;

$crate::macros::support::poll_fn(|cx| {
// Track if any branch returns pending. If no branch completes
// **or** returns pending, this implies that all branches are
Expand Down Expand Up @@ -497,7 +503,7 @@ macro_rules! select {

// Extract the future for this branch from the
// tuple
let ( $($skip,)* fut, .. ) = &mut futures;
let ( $($skip,)* fut, .. ) = &mut *futures;

// Safety: future is stored on the stack above
// and never moved.
Expand Down
8 changes: 7 additions & 1 deletion tokio/src/macros/try_join.rs
Expand Up @@ -120,6 +120,12 @@ macro_rules! try_join {
// the requirement of `Pin::new_unchecked` called below.
let mut futures = ( $( maybe_done($e), )* );

// This assignment makes sure that the `poll_fn` closure only has a
// reference to the futures, instead of taking ownership of them. This
// mitigates the issue described in
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
let mut futures = &mut futures;

// Each time the future created by poll_fn is polled, a different future will be polled first
// to ensure every future passed to join! gets a chance to make progress even if
// one of the futures consumes the whole budget.
Expand Down Expand Up @@ -152,7 +158,7 @@ macro_rules! try_join {
to_run -= 1;

// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;
let ( $($skip,)* fut, .. ) = &mut *futures;

// Safety: future is stored on the stack above
// and never moved.
Expand Down
6 changes: 4 additions & 2 deletions tokio/tests/macros_join.rs
Expand Up @@ -3,6 +3,7 @@
use std::sync::Arc;

#[cfg(tokio_wasm_not_wasi)]
#[cfg(target_pointer_width = "64")]
use wasm_bindgen_test::wasm_bindgen_test as test;
#[cfg(tokio_wasm_not_wasi)]
use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test;
Expand Down Expand Up @@ -64,6 +65,7 @@ async fn two_await() {
}

#[test]
#[cfg(target_pointer_width = "64")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we be making similar assertions about the size of the futures on 32-bit platforms as well? or, calculate the expected size by adding either 8 or 4 bytes based on the value of target_pointer_width?

fn join_size() {
use futures::future;
use std::mem;
Expand All @@ -72,14 +74,14 @@ fn join_size() {
let ready = future::ready(0i32);
tokio::join!(ready)
};
assert_eq!(mem::size_of_val(&fut), 20);
assert_eq!(mem::size_of_val(&fut), 32);

let fut = async {
let ready1 = future::ready(0i32);
let ready2 = future::ready(0i32);
tokio::join!(ready1, ready2)
};
assert_eq!(mem::size_of_val(&fut), 32);
assert_eq!(mem::size_of_val(&fut), 48);
}

async fn non_cooperative_task(permits: Arc<Semaphore>) -> usize {
Expand Down
7 changes: 4 additions & 3 deletions tokio/tests/macros_select.rs
Expand Up @@ -207,6 +207,7 @@ async fn nested() {
}

#[maybe_tokio_test]
#[cfg(target_pointer_width = "64")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as with join, could we make a separate set of assertions about the future size on 32-bit platforms?

async fn struct_size() {
use futures::future;
use std::mem;
Expand All @@ -219,7 +220,7 @@ async fn struct_size() {
}
};

assert!(mem::size_of_val(&fut) <= 32);
assert_eq!(mem::size_of_val(&fut), 40);

let fut = async {
let ready1 = future::ready(0i32);
Expand All @@ -231,7 +232,7 @@ async fn struct_size() {
}
};

assert!(mem::size_of_val(&fut) <= 40);
assert_eq!(mem::size_of_val(&fut), 48);

let fut = async {
let ready1 = future::ready(0i32);
Expand All @@ -245,7 +246,7 @@ async fn struct_size() {
}
};

assert!(mem::size_of_val(&fut) <= 48);
assert_eq!(mem::size_of_val(&fut), 56);
}

#[maybe_tokio_test]
Expand Down
5 changes: 3 additions & 2 deletions tokio/tests/macros_try_join.rs
Expand Up @@ -88,6 +88,7 @@ async fn err_abort_early() {
}

#[test]
#[cfg(target_pointer_width = "64")]
fn join_size() {
use futures::future;
use std::mem;
Expand All @@ -96,14 +97,14 @@ fn join_size() {
let ready = future::ready(ok(0i32));
tokio::try_join!(ready)
};
assert_eq!(mem::size_of_val(&fut), 20);
assert_eq!(mem::size_of_val(&fut), 32);

let fut = async {
let ready1 = future::ready(ok(0i32));
let ready2 = future::ready(ok(0i32));
tokio::try_join!(ready1, ready2)
};
assert_eq!(mem::size_of_val(&fut), 32);
assert_eq!(mem::size_of_val(&fut), 48);
}

fn ok<T>(val: T) -> Result<T, ()> {
Expand Down