diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index d9e8428a837..6de657481e0 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -74,6 +74,7 @@ cfg_rt! { use std::fmt; #[derive(Debug, Clone, Copy)] + #[must_use] pub(crate) enum EnterRuntime { /// Currently in a runtime context. #[cfg_attr(not(feature = "rt"), allow(dead_code))] @@ -84,17 +85,22 @@ cfg_rt! { } #[derive(Debug)] + #[must_use] pub(crate) struct SetCurrentGuard { old_handle: Option, old_seed: RngSeed, } /// Guard tracking that a caller has entered a runtime context. + #[must_use] pub(crate) struct EnterRuntimeGuard { pub(crate) blocking: BlockingRegionGuard, + #[allow(dead_code)] // Only tracking the guard. + pub(crate) handle: SetCurrentGuard, } /// Guard tracking that a caller has entered a blocking region. + #[must_use] pub(crate) struct BlockingRegionGuard { _p: PhantomData>, } @@ -121,10 +127,7 @@ cfg_rt! { /// executor. #[track_caller] pub(crate) fn enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> EnterRuntimeGuard { - if let Some(enter) = try_enter_runtime(allow_block_in_place) { - // Set the current runtime handle. This should not fail. A later - // cleanup will remove the unwrap(). - try_set_current(handle).unwrap(); + if let Some(enter) = try_enter_runtime(handle, allow_block_in_place) { return enter; } @@ -138,7 +141,7 @@ cfg_rt! { /// Tries to enter a runtime context, returns `None` if already in a runtime /// context. - fn try_enter_runtime(allow_block_in_place: bool) -> Option { + fn try_enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> Option { CONTEXT.with(|c| { if c.runtime.get().is_entered() { None @@ -146,6 +149,7 @@ cfg_rt! { c.runtime.set(EnterRuntime::Entered { allow_block_in_place }); Some(EnterRuntimeGuard { blocking: BlockingRegionGuard::new(), + handle: c.set_current(handle), }) } }) diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index a9439c632c0..1039162ac2f 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -243,14 +243,6 @@ rt_test! { tokio::spawn(async move { let (done_tx, mut done_rx) = mpsc::unbounded_channel(); - /* - for _ in 0..100 { - tokio::spawn(async move { }); - } - - tokio::task::yield_now().await; - */ - let mut txs = (0..ITER) .map(|i| { let (tx, rx) = oneshot::channel(); @@ -291,6 +283,31 @@ rt_test! { } } + #[test] + fn spawn_one_from_block_on_called_on_handle() { + let rt = rt(); + let (tx, rx) = oneshot::channel(); + + #[allow(clippy::async_yields_async)] + let handle = rt.handle().block_on(async { + tokio::spawn(async move { + tx.send("ZOMG").unwrap(); + "DONE" + }) + }); + + let out = rt.block_on(async { + let msg = assert_ok!(rx.await); + + let out = assert_ok!(handle.await); + assert_eq!(out, "DONE"); + + msg + }); + + assert_eq!(out, "ZOMG"); + } + #[test] fn spawn_await_chain() { let rt = rt();