diff --git a/tokio-util/src/context.rs b/tokio-util/src/context.rs index ae954d85c4a..990c0f14442 100644 --- a/tokio-util/src/context.rs +++ b/tokio-util/src/context.rs @@ -34,7 +34,8 @@ impl Future for TokioContext<'_, F> { let handle = me.handle; let fut = me.inner; - handle.enter(|| fut.poll(cx)) + let _enter = handle.enter(); + fut.poll(cx) } } diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index 2d44f896c1b..d0f2c1c8d14 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -232,11 +232,9 @@ impl Spawner { builder .spawn(move || { // Only the reference should be moved into the closure - let rt = &rt; - rt.enter(move || { - rt.blocking_spawner.inner.run(worker_id); - drop(shutdown_tx); - }) + let _enter = crate::runtime::context::enter(rt.clone()); + rt.blocking_spawner.inner.run(worker_id); + drop(shutdown_tx); }) .unwrap() } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 735e9b6a975..bc0b29b926d 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -508,7 +508,8 @@ cfg_rt_multi_thread! { }; // Spawn the thread pool workers - handle.enter(|| launch.launch()); + let _enter = crate::runtime::context::enter(handle.clone()); + launch.launch(); Ok(Runtime { kind: Kind::ThreadPool(scheduler), diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 9dfca8d8ab0..0817019db48 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -60,24 +60,20 @@ cfg_rt! { /// Set this [`Handle`] as the current active [`Handle`]. /// /// [`Handle`]: Handle -pub(crate) fn enter(new: Handle, f: F) -> R -where - F: FnOnce() -> R, -{ - struct DropGuard(Option); - - impl Drop for DropGuard { - fn drop(&mut self) { - CONTEXT.with(|ctx| { - *ctx.borrow_mut() = self.0.take(); - }); - } - } - - let _guard = CONTEXT.with(|ctx| { +pub(crate) fn enter(new: Handle) -> EnterGuard { + CONTEXT.with(|ctx| { let old = ctx.borrow_mut().replace(new); - DropGuard(old) - }); + EnterGuard(old) + }) +} + +#[derive(Debug)] +pub(crate) struct EnterGuard(Option); - f() +impl Drop for EnterGuard { + fn drop(&mut self) { + CONTEXT.with(|ctx| { + *ctx.borrow_mut() = self.0.take(); + }); + } } diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index d48b6242320..9c2cfa5f1db 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,4 +1,4 @@ -use crate::runtime::{blocking, context, driver, Spawner}; +use crate::runtime::{blocking, driver, Spawner}; /// Handle to the runtime. /// @@ -27,13 +27,13 @@ pub(crate) struct Handle { } impl Handle { - /// Enter the runtime context. This allows you to construct types that must - /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. - /// It will also allow you to call methods such as [`tokio::spawn`]. - pub(crate) fn enter(&self, f: F) -> R - where - F: FnOnce() -> R, - { - context::enter(self.clone(), f) - } + // /// Enter the runtime context. This allows you to construct types that must + // /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. + // /// It will also allow you to call methods such as [`tokio::spawn`]. + // pub(crate) fn enter(&self, f: F) -> R + // where + // F: FnOnce() -> R, + // { + // context::enter(self.clone(), f) + // } } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 7712a7f8525..7ce3881cb68 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -262,6 +262,16 @@ cfg_rt! { blocking_pool: BlockingPool, } + /// Runtime context guard. + /// + /// Returned by [`Runtime::enter`], the context guard exits the runtime + /// context on drop. + #[derive(Debug)] + pub struct EnterGuard<'a> { + rt: &'a Runtime, + guard: context::EnterGuard, + } + /// The runtime executor is either a thread-pool or a current-thread executor. #[derive(Debug)] enum Kind { @@ -356,25 +366,26 @@ cfg_rt! { } } - /// Run a future to completion on the Tokio runtime. This is the runtime's - /// entry point. + /// Run a future to completion on the Tokio runtime. This is the + /// runtime's entry point. /// /// This runs the given future on the runtime, blocking until it is - /// complete, and yielding its resolved result. Any tasks or timers which - /// the future spawns internally will be executed on the runtime. + /// complete, and yielding its resolved result. Any tasks or timers + /// which the future spawns internally will be executed on the runtime. /// - /// When this runtime is configured with `core_threads = 0`, only the first call - /// to `block_on` will run the IO and timer drivers. Calls to other methods _before_ the first - /// `block_on` completes will just hook into the driver running on the thread - /// that first called `block_on`. This means that the driver may be passed - /// from thread to thread by the user between calls to `block_on`. + /// When this runtime is configured with `core_threads = 0`, only the + /// first call to `block_on` will run the IO and timer drivers. Calls to + /// other methods _before_ the first `block_on` completes will just hook + /// into the driver running on the thread that first called `block_on`. + /// This means that the driver may be passed from thread to thread by + /// the user between calls to `block_on`. /// /// This method may not be called from an asynchronous context. /// /// # Panics /// - /// This function panics if the provided future panics, or if called within an - /// asynchronous execution context. + /// This function panics if the provided future panics, or if called + /// within an asynchronous execution context. /// /// # Examples /// @@ -392,17 +403,21 @@ cfg_rt! { /// /// [handle]: fn@Handle::block_on pub fn block_on(&self, future: F) -> F::Output { - self.handle.enter(|| match &self.kind { + let _enter = self.enter(); + + match &self.kind { #[cfg(feature = "rt")] Kind::CurrentThread(exec) => exec.block_on(future), #[cfg(feature = "rt-multi-thread")] Kind::ThreadPool(exec) => exec.block_on(future), - }) + } } - /// Enter the runtime context. This allows you to construct types that must - /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. - /// It will also allow you to call methods such as [`tokio::spawn`]. + /// Enter the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. /// /// [`Sleep`]: struct@crate::time::Sleep /// [`TcpStream`]: struct@crate::net::TcpStream @@ -426,14 +441,15 @@ cfg_rt! { /// let s = "Hello World!".to_string(); /// /// // By entering the context, we tie `tokio::spawn` to this executor. - /// rt.enter(|| function_that_spawns(s)); + /// let _guard = rt.enter(); + /// function_that_spawns(s); /// } /// ``` - pub fn enter(&self, f: F) -> R - where - F: FnOnce() -> R, - { - self.handle.enter(f) + pub fn enter(&self) -> EnterGuard<'_> { + EnterGuard { + rt: self, + guard: context::enter(self.handle.clone()), + } } /// Shutdown the runtime, waiting for at most `duration` for all spawned diff --git a/tokio/src/runtime/tests/loom_blocking.rs b/tokio/src/runtime/tests/loom_blocking.rs index 8f0b901493b..8fb54c5657e 100644 --- a/tokio/src/runtime/tests/loom_blocking.rs +++ b/tokio/src/runtime/tests/loom_blocking.rs @@ -8,14 +8,15 @@ fn blocking_shutdown() { let v = Arc::new(()); let rt = mk_runtime(1); - rt.enter(|| { + { + let _enter = rt.enter(); for _ in 0..2 { let v = v.clone(); crate::task::spawn_blocking(move || { assert!(1 < Arc::strong_count(&v)); }); } - }); + } drop(rt); assert_eq!(1, Arc::strong_count(&v)); diff --git a/tokio/src/signal/windows.rs b/tokio/src/signal/windows.rs index 46271722515..1e783622bb1 100644 --- a/tokio/src/signal/windows.rs +++ b/tokio/src/signal/windows.rs @@ -253,21 +253,20 @@ mod tests { #[test] fn ctrl_c() { let rt = rt(); + let _enter = rt.enter(); - rt.enter(|| { - let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); + let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); - assert_pending!(ctrl_c.poll()); + assert_pending!(ctrl_c.poll()); - // Windows doesn't have a good programmatic way of sending events - // like sending signals on Unix, so we'll stub out the actual OS - // integration and test that our handling works. - unsafe { - super::handler(CTRL_C_EVENT); - } + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_C_EVENT); + } - assert_ready_ok!(ctrl_c.poll()); - }); + assert_ready_ok!(ctrl_c.poll()); } #[test] diff --git a/tokio/tests/io_driver.rs b/tokio/tests/io_driver.rs index 82fb10214b1..9a40247ea98 100644 --- a/tokio/tests/io_driver.rs +++ b/tokio/tests/io_driver.rs @@ -67,11 +67,10 @@ fn test_drop_on_notify() { })); { - rt.enter(|| { - let waker = waker_ref(&task); - let mut cx = Context::from_waker(&waker); - assert_pending!(task.future.lock().unwrap().as_mut().poll(&mut cx)); - }); + let _enter = rt.enter(); + let waker = waker_ref(&task); + let mut cx = Context::from_waker(&waker); + assert_pending!(task.future.lock().unwrap().as_mut().poll(&mut cx)); } // Get the address diff --git a/tokio/tests/io_driver_drop.rs b/tokio/tests/io_driver_drop.rs index 72c7ae2552b..631e66e9fbe 100644 --- a/tokio/tests/io_driver_drop.rs +++ b/tokio/tests/io_driver_drop.rs @@ -9,10 +9,11 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task}; fn tcp_doesnt_block() { let rt = rt(); - let listener = rt.enter(|| { + let listener = { + let _enter = rt.enter(); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); TcpListener::from_std(listener).unwrap() - }); + }; drop(rt); @@ -27,10 +28,11 @@ fn tcp_doesnt_block() { fn drop_wakes() { let rt = rt(); - let listener = rt.enter(|| { + let listener = { + let _enter = rt.enter(); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); TcpListener::from_std(listener).unwrap() - }); + }; let mut task = task::spawn(async move { assert_err!(listener.accept().await); diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index a4091616ecf..74a94d5b9d1 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -554,23 +554,6 @@ rt_test! { }); } - #[test] - fn spawn_blocking_after_shutdown() { - let rt = rt(); - let handle = rt.clone(); - - // Shutdown - drop(rt); - - handle.enter(|| { - let res = task::spawn_blocking(|| unreachable!()); - - // Avoid using a tokio runtime - let out = futures::executor::block_on(res); - assert!(out.is_err()); - }); - } - #[test] fn always_active_parker() { // This test it to show that we will always have @@ -713,9 +696,10 @@ rt_test! { #[test] fn enter_and_spawn() { let rt = rt(); - let handle = rt.enter(|| { + let handle = { + let _enter = rt.enter(); tokio::spawn(async {}) - }); + }; assert_ok!(rt.block_on(handle)); }