diff --git a/tokio/src/task/consume_budget.rs b/tokio/src/task/consume_budget.rs new file mode 100644 index 00000000000..c8b2d7e5ceb --- /dev/null +++ b/tokio/src/task/consume_budget.rs @@ -0,0 +1,45 @@ +use std::task::Poll; + +/// Consumes a unit of budget and returns the execution back to the Tokio +/// runtime *if* the task's coop budget was exhausted. +/// +/// The task will only yield if its entire coop budget has been exhausted. +/// This function can can be used in order to insert optional yield points into long +/// computations that do not use Tokio resources like sockets or semaphores, +/// without redundantly yielding to the runtime each time. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// # Examples +/// +/// Make sure that a function which returns a sum of (potentially lots of) +/// iterated values is cooperative. +/// +/// ``` +/// async fn sum_iterator(input: &mut impl std::iter::Iterator) -> i64 { +/// let mut sum: i64 = 0; +/// while let Some(i) = input.next() { +/// sum += i; +/// tokio::task::consume_budget().await +/// } +/// sum +/// } +/// ``` +/// [unstable]: crate#unstable-features +#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] +pub async fn consume_budget() { + let mut status = Poll::Pending; + + crate::future::poll_fn(move |cx| { + if status.is_ready() { + return status; + } + status = crate::coop::poll_proceed(cx).map(|restore| { + restore.made_progress(); + }); + status + }) + .await +} diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index cebc269bb40..4057d535218 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -291,6 +291,11 @@ cfg_rt! { mod yield_now; pub use yield_now::yield_now; + cfg_unstable! { + mod consume_budget; + pub use consume_budget::consume_budget; + } + mod local; pub use local::{spawn_local, LocalSet}; diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 1f7a378549b..14e19095933 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -1054,6 +1054,31 @@ rt_test! { }); } + #[cfg(tokio_unstable)] + #[test] + fn coop_consume_budget() { + let rt = rt(); + + rt.block_on(async { + poll_fn(|cx| { + let counter = Arc::new(std::sync::Mutex::new(0)); + let counter_clone = Arc::clone(&counter); + let mut worker = Box::pin(async move { + // Consume the budget until a yield happens + for _ in 0..1000 { + *counter.lock().unwrap() += 1; + task::consume_budget().await + } + }); + // Assert that the worker was yielded and it didn't manage + // to finish the whole work (assuming the total budget of 128) + assert!(Pin::new(&mut worker).poll(cx).is_pending()); + assert!(*counter_clone.lock().unwrap() < 1000); + std::task::Poll::Ready(()) + }).await; + }); + } + // Tests that the "next task" scheduler optimization is not able to starve // other tasks. #[test]