Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

task: add track_caller to block_in_place and spawn_local #5034

Merged
merged 2 commits into from Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 14 additions & 6 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Expand Up @@ -249,6 +249,7 @@ pub(super) fn create(
(handle, launch)
}

#[track_caller]
pub(crate) fn block_in_place<F, R>(f: F) -> R
where
F: FnOnce() -> R,
Expand All @@ -275,7 +276,7 @@ where

let mut had_entered = false;

CURRENT.with(|maybe_cx| {
let setup_result = CURRENT.with(|maybe_cx| {
match (crate::runtime::enter::context(), maybe_cx.is_some()) {
(EnterContext::Entered { .. }, true) => {
// We are on a thread pool runtime thread, so we just need to
Expand All @@ -288,22 +289,24 @@ where
// method:
if allow_blocking {
had_entered = true;
return;
return Ok(());
} else {
// This probably means we are on the current_thread runtime or in a
// LocalSet, where it is _not_ okay to block.
panic!("can call blocking only when running on the multi-threaded runtime");
return Err(
"can call blocking only when running on the multi-threaded runtime",
);
}
}
(EnterContext::NotEntered, true) => {
// This is a nested call to block_in_place (we already exited).
// All the necessary setup has already been done.
return;
return Ok(());
}
(EnterContext::NotEntered, false) => {
// We are outside of the tokio runtime, so blocking is fine.
// We can also skip all of the thread pool blocking setup steps.
return;
return Ok(());
}
}

Expand All @@ -312,7 +315,7 @@ where
// Get the worker core. If none is set, then blocking is fine!
let core = match cx.core.borrow_mut().take() {
Some(core) => core,
None => return,
None => return Ok(()),
};

// The parker should be set here
Expand All @@ -331,8 +334,13 @@ where
// steal the core back.
let worker = cx.worker.clone();
runtime::spawn_blocking(move || run(worker));
Ok(())
});

if let Err(panic_message) = setup_result {
panic!("{}", panic_message);
}

if had_entered {
// Unset the current task's budget. Blocking sections are not
// constrained by task budgets.
Expand Down
1 change: 1 addition & 0 deletions tokio/src/task/blocking.rs
Expand Up @@ -70,6 +70,7 @@ cfg_rt_multi_thread! {
/// This function panics if called from a [`current_thread`] runtime.
///
/// [`current_thread`]: fn@crate::runtime::Builder::new_current_thread
#[track_caller]
pub fn block_in_place<F, R>(f: F) -> R
where
F: FnOnce() -> R,
Expand Down
10 changes: 4 additions & 6 deletions tokio/src/task/local.rs
Expand Up @@ -314,12 +314,10 @@ cfg_rt! {
where F: Future + 'static,
F::Output: 'static
{
CURRENT.with(|maybe_cx| {
match maybe_cx.get() {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}
})
match CURRENT.with(|maybe_cx| maybe_cx.get()) {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}
}
}

Expand Down
34 changes: 32 additions & 2 deletions tokio/tests/task_panic.rs
Expand Up @@ -3,13 +3,43 @@

use futures::future;
use std::error::Error;
use tokio::{runtime::Builder, spawn, task};
use tokio::runtime::Builder;
use tokio::task::{self, block_in_place};

mod support {
pub mod panic;
}
use support::panic::test_panic;

#[test]
fn block_in_place_panic_caller() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let rt = Builder::new_current_thread().enable_all().build().unwrap();
rt.block_on(async {
block_in_place(|| {});
});
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

#[test]
fn local_set_spawn_local_panic_caller() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let _local = task::LocalSet::new();

let _ = task::spawn_local(async {});
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

#[test]
fn local_set_block_on_panic_caller() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
Expand All @@ -30,7 +60,7 @@ fn local_set_block_on_panic_caller() -> Result<(), Box<dyn Error>> {
#[test]
fn spawn_panic_caller() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
spawn(future::pending::<()>());
tokio::spawn(future::pending::<()>());
});

// The panic location should be in this file
Expand Down