Skip to content

Commit

Permalink
rt: use internal ThreadId implementation (#5329)
Browse files Browse the repository at this point in the history
The version provided by `std` has limitations, including no way to try
to get a thread ID without panicking.
  • Loading branch information
carllerche committed Dec 30, 2022
1 parent 048049f commit c6552c5
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 30 deletions.
22 changes: 21 additions & 1 deletion tokio/src/runtime/context.rs
Expand Up @@ -15,6 +15,10 @@ cfg_rt! {
}

struct Context {
/// Uniquely identifies the current thread
#[cfg(feature = "rt")]
thread_id: Cell<Option<ThreadId>>,

/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell<Option<scheduler::Handle>>,
Expand Down Expand Up @@ -46,6 +50,9 @@ struct Context {
tokio_thread_local! {
static CONTEXT: Context = {
Context {
#[cfg(feature = "rt")]
thread_id: Cell::new(None),

/// Tracks the current runtime handle to use when spawning,
/// accessing drivers, etc...
#[cfg(feature = "rt")]
Expand Down Expand Up @@ -82,10 +89,23 @@ pub(super) fn budget<R>(f: impl FnOnce(&Cell<coop::Budget>) -> R) -> Result<R, A
}

cfg_rt! {
use crate::runtime::TryCurrentError;
use crate::runtime::{ThreadId, TryCurrentError};

use std::fmt;

pub(crate) fn thread_id() -> Result<ThreadId, AccessError> {
CONTEXT.try_with(|ctx| {
match ctx.thread_id.get() {
Some(id) => id,
None => {
let id = ThreadId::next();
ctx.thread_id.set(Some(id));
id
}
}
})
}

#[derive(Debug, Clone, Copy)]
#[must_use]
pub(crate) enum EnterRuntime {
Expand Down
3 changes: 3 additions & 0 deletions tokio/src/runtime/mod.rs
Expand Up @@ -237,6 +237,9 @@ cfg_rt! {
mod runtime;
pub use runtime::{Runtime, RuntimeFlavor};

mod thread_id;
pub(crate) use thread_id::ThreadId;

cfg_metrics! {
mod metrics;
pub use metrics::RuntimeMetrics;
Expand Down
31 changes: 31 additions & 0 deletions tokio/src/runtime/thread_id.rs
@@ -0,0 +1,31 @@
use std::num::NonZeroU64;

#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub(crate) struct ThreadId(NonZeroU64);

impl ThreadId {
pub(crate) fn next() -> Self {
use crate::loom::sync::atomic::{Ordering::Relaxed, StaticAtomicU64};

static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(0);

let mut last = NEXT_ID.load(Relaxed);
loop {
let id = match last.checked_add(1) {
Some(id) => id,
None => exhausted(),
};

match NEXT_ID.compare_exchange_weak(last, id, Relaxed, Relaxed) {
Ok(_) => return ThreadId(NonZeroU64::new(id).unwrap()),
Err(id) => last = id,
}
}
}
}

#[cold]
#[allow(dead_code)]
fn exhausted() -> ! {
panic!("failed to generate unique thread ID: bitspace exhausted")
}
34 changes: 8 additions & 26 deletions tokio/src/task/local.rs
@@ -1,8 +1,8 @@
//! Runs `!Send` futures on the current thread.
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::{Arc, Mutex};
use crate::loom::thread::{self, ThreadId};
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::runtime::{context, ThreadId};
use crate::sync::AtomicWaker;
use crate::util::RcCell;

Expand Down Expand Up @@ -277,12 +277,10 @@ pin_project! {
}

tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
thread_id: Cell::new(None),
ctx: RcCell::new(),
} });

struct LocalData {
thread_id: Cell<Option<ThreadId>>,
ctx: RcCell<Context>,
}

Expand Down Expand Up @@ -379,12 +377,14 @@ impl fmt::Debug for LocalEnterGuard {
impl LocalSet {
/// Returns a new local task set.
pub fn new() -> LocalSet {
let owner = context::thread_id().expect("cannot create LocalSet during thread shutdown");

LocalSet {
tick: Cell::new(0),
context: Rc::new(Context {
shared: Arc::new(Shared {
local_state: LocalState {
owner: thread_id().expect("cannot create LocalSet during thread shutdown"),
owner,
owned: LocalOwnedTasks::new(),
local_queue: UnsafeCell::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
},
Expand Down Expand Up @@ -949,7 +949,7 @@ impl Shared {

// We are on the thread that owns the `LocalSet`, so we can
// wake to the local queue.
_ if localdata.get_id() == Some(self.local_state.owner) => {
_ if context::thread_id().ok() == Some(self.local_state.owner) => {
unsafe {
// Safety: we just checked that the thread ID matches
// the localset's owner, so this is safe.
Expand Down Expand Up @@ -1093,7 +1093,9 @@ impl LocalState {
// if we couldn't get the thread ID because we're dropping the local
// data, skip the assertion --- the `Drop` impl is not going to be
// called from another thread, because `LocalSet` is `!Send`
thread_id().map(|id| id == self.owner).unwrap_or(true),
context::thread_id()
.map(|id| id == self.owner)
.unwrap_or(true),
"`LocalSet`'s local run queue must not be accessed by another thread!"
);
}
Expand All @@ -1103,26 +1105,6 @@ impl LocalState {
// ensure they are on the same thread that owns the `LocalSet`.
unsafe impl Send for LocalState {}

impl LocalData {
fn get_id(&self) -> Option<ThreadId> {
self.thread_id.get()
}

fn get_or_insert_id(&self) -> ThreadId {
self.thread_id.get().unwrap_or_else(|| {
let id = thread::current().id();
self.thread_id.set(Some(id));
id
})
}
}

fn thread_id() -> Option<ThreadId> {
CURRENT
.try_with(|localdata| localdata.get_or_insert_id())
.ok()
}

#[cfg(all(test, not(loom)))]
mod tests {
use super::*;
Expand Down
6 changes: 3 additions & 3 deletions tokio/tests/rt_metrics.rs
Expand Up @@ -141,16 +141,16 @@ fn worker_noop_count() {
time::sleep(Duration::from_millis(1)).await;
});
drop(rt);
assert!(2 <= metrics.worker_noop_count(0));
assert!(0 < metrics.worker_noop_count(0));

let rt = threaded();
let metrics = rt.metrics();
rt.block_on(async {
time::sleep(Duration::from_millis(1)).await;
});
drop(rt);
assert!(1 <= metrics.worker_noop_count(0));
assert!(1 <= metrics.worker_noop_count(1));
assert!(0 < metrics.worker_noop_count(0));
assert!(0 < metrics.worker_noop_count(1));
}

#[test]
Expand Down

0 comments on commit c6552c5

Please sign in to comment.