diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 2e54c8ba366..fef53cab8a6 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -15,6 +15,10 @@ cfg_rt! { } struct Context { + /// Uniquely identifies the current thread + #[cfg(feature = "rt")] + thread_id: Cell>, + /// Handle to the runtime scheduler running on the current thread. #[cfg(feature = "rt")] handle: RefCell>, @@ -46,6 +50,9 @@ struct Context { tokio_thread_local! { static CONTEXT: Context = { Context { + #[cfg(feature = "rt")] + thread_id: Cell::new(None), + /// Tracks the current runtime handle to use when spawning, /// accessing drivers, etc... #[cfg(feature = "rt")] @@ -82,10 +89,23 @@ pub(super) fn budget(f: impl FnOnce(&Cell) -> R) -> Result Result { + CONTEXT.try_with(|ctx| { + match ctx.thread_id.get() { + Some(id) => id, + None => { + let id = ThreadId::next(); + ctx.thread_id.set(Some(id)); + id + } + } + }) + } + #[derive(Debug, Clone, Copy)] #[must_use] pub(crate) enum EnterRuntime { diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 45b79b0ac81..b6f43ea1754 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -237,6 +237,9 @@ cfg_rt! { mod runtime; pub use runtime::{Runtime, RuntimeFlavor}; + mod thread_id; + pub(crate) use thread_id::ThreadId; + cfg_metrics! { mod metrics; pub use metrics::RuntimeMetrics; diff --git a/tokio/src/runtime/thread_id.rs b/tokio/src/runtime/thread_id.rs new file mode 100644 index 00000000000..ef392897963 --- /dev/null +++ b/tokio/src/runtime/thread_id.rs @@ -0,0 +1,31 @@ +use std::num::NonZeroU64; + +#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)] +pub(crate) struct ThreadId(NonZeroU64); + +impl ThreadId { + pub(crate) fn next() -> Self { + use crate::loom::sync::atomic::{Ordering::Relaxed, StaticAtomicU64}; + + static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(0); + + let mut last = NEXT_ID.load(Relaxed); + loop { + let id = match last.checked_add(1) { + Some(id) => id, + None => exhausted(), + }; + + match NEXT_ID.compare_exchange_weak(last, id, Relaxed, Relaxed) { + Ok(_) => return ThreadId(NonZeroU64::new(id).unwrap()), + Err(id) => last = id, + } + } + } +} + +#[cold] +#[allow(dead_code)] +fn exhausted() -> ! { + panic!("failed to generate unique thread ID: bitspace exhausted") +} diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index e4a198bd053..cc4500a58e7 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,8 +1,8 @@ //! Runs `!Send` futures on the current thread. use crate::loom::cell::UnsafeCell; use crate::loom::sync::{Arc, Mutex}; -use crate::loom::thread::{self, ThreadId}; use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; +use crate::runtime::{context, ThreadId}; use crate::sync::AtomicWaker; use crate::util::RcCell; @@ -277,12 +277,10 @@ pin_project! { } tokio_thread_local!(static CURRENT: LocalData = const { LocalData { - thread_id: Cell::new(None), ctx: RcCell::new(), } }); struct LocalData { - thread_id: Cell>, ctx: RcCell, } @@ -379,12 +377,14 @@ impl fmt::Debug for LocalEnterGuard { impl LocalSet { /// Returns a new local task set. pub fn new() -> LocalSet { + let owner = context::thread_id().expect("cannot create LocalSet during thread shutdown"); + LocalSet { tick: Cell::new(0), context: Rc::new(Context { shared: Arc::new(Shared { local_state: LocalState { - owner: thread_id().expect("cannot create LocalSet during thread shutdown"), + owner, owned: LocalOwnedTasks::new(), local_queue: UnsafeCell::new(VecDeque::with_capacity(INITIAL_CAPACITY)), }, @@ -949,7 +949,7 @@ impl Shared { // We are on the thread that owns the `LocalSet`, so we can // wake to the local queue. - _ if localdata.get_id() == Some(self.local_state.owner) => { + _ if context::thread_id().ok() == Some(self.local_state.owner) => { unsafe { // Safety: we just checked that the thread ID matches // the localset's owner, so this is safe. @@ -1093,7 +1093,9 @@ impl LocalState { // if we couldn't get the thread ID because we're dropping the local // data, skip the assertion --- the `Drop` impl is not going to be // called from another thread, because `LocalSet` is `!Send` - thread_id().map(|id| id == self.owner).unwrap_or(true), + context::thread_id() + .map(|id| id == self.owner) + .unwrap_or(true), "`LocalSet`'s local run queue must not be accessed by another thread!" ); } @@ -1103,26 +1105,6 @@ impl LocalState { // ensure they are on the same thread that owns the `LocalSet`. unsafe impl Send for LocalState {} -impl LocalData { - fn get_id(&self) -> Option { - self.thread_id.get() - } - - fn get_or_insert_id(&self) -> ThreadId { - self.thread_id.get().unwrap_or_else(|| { - let id = thread::current().id(); - self.thread_id.set(Some(id)); - id - }) - } -} - -fn thread_id() -> Option { - CURRENT - .try_with(|localdata| localdata.get_or_insert_id()) - .ok() -} - #[cfg(all(test, not(loom)))] mod tests { use super::*; diff --git a/tokio/tests/rt_metrics.rs b/tokio/tests/rt_metrics.rs index 4b98d234c41..fdb2fb5f551 100644 --- a/tokio/tests/rt_metrics.rs +++ b/tokio/tests/rt_metrics.rs @@ -141,7 +141,7 @@ fn worker_noop_count() { time::sleep(Duration::from_millis(1)).await; }); drop(rt); - assert!(2 <= metrics.worker_noop_count(0)); + assert!(0 < metrics.worker_noop_count(0)); let rt = threaded(); let metrics = rt.metrics(); @@ -149,8 +149,8 @@ fn worker_noop_count() { time::sleep(Duration::from_millis(1)).await; }); drop(rt); - assert!(1 <= metrics.worker_noop_count(0)); - assert!(1 <= metrics.worker_noop_count(1)); + assert!(0 < metrics.worker_noop_count(0)); + assert!(0 < metrics.worker_noop_count(1)); } #[test]