Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: use internal ThreadId implementation #5329

Merged
merged 2 commits into from Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 21 additions & 1 deletion tokio/src/runtime/context.rs
Expand Up @@ -15,6 +15,10 @@ cfg_rt! {
}

struct Context {
/// Uniquely identifies the current thread
#[cfg(feature = "rt")]
thread_id: Cell<Option<ThreadId>>,

/// Handle to the runtime scheduler running on the current thread.
#[cfg(feature = "rt")]
handle: RefCell<Option<scheduler::Handle>>,
Expand Down Expand Up @@ -46,6 +50,9 @@ struct Context {
tokio_thread_local! {
static CONTEXT: Context = {
Context {
#[cfg(feature = "rt")]
thread_id: Cell::new(None),

/// Tracks the current runtime handle to use when spawning,
/// accessing drivers, etc...
#[cfg(feature = "rt")]
Expand Down Expand Up @@ -82,10 +89,23 @@ pub(super) fn budget<R>(f: impl FnOnce(&Cell<coop::Budget>) -> R) -> Result<R, A
}

cfg_rt! {
use crate::runtime::TryCurrentError;
use crate::runtime::{ThreadId, TryCurrentError};

use std::fmt;

pub(crate) fn thread_id() -> Result<ThreadId, AccessError> {
CONTEXT.try_with(|ctx| {
match ctx.thread_id.get() {
Some(id) => id,
None => {
let id = ThreadId::next();
ctx.thread_id.set(Some(id));
id
}
}
})
}

#[derive(Debug, Clone, Copy)]
#[must_use]
pub(crate) enum EnterRuntime {
Expand Down
3 changes: 3 additions & 0 deletions tokio/src/runtime/mod.rs
Expand Up @@ -237,6 +237,9 @@ cfg_rt! {
mod runtime;
pub use runtime::{Runtime, RuntimeFlavor};

mod thread_id;
pub(crate) use thread_id::ThreadId;

cfg_metrics! {
mod metrics;
pub use metrics::RuntimeMetrics;
Expand Down
31 changes: 31 additions & 0 deletions tokio/src/runtime/thread_id.rs
@@ -0,0 +1,31 @@
use std::num::NonZeroU64;

#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub(crate) struct ThreadId(NonZeroU64);

impl ThreadId {
pub(crate) fn next() -> Self {
use crate::loom::sync::atomic::{Ordering::Relaxed, StaticAtomicU64};

static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(0);

let mut last = NEXT_ID.load(Relaxed);
loop {
let id = match last.checked_add(1) {
Some(id) => id,
None => exhausted(),
};

match NEXT_ID.compare_exchange_weak(last, id, Relaxed, Relaxed) {
Ok(_) => return ThreadId(NonZeroU64::new(id).unwrap()),
Comment on lines +14 to +20
Copy link
Contributor

@Darksonn Darksonn Dec 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unwrap is a bit unfortunate. One option is to define a method call like this one:

fn checked_add_one(num: u64) -> Option<NonZeroU64> {
    NonZeroU64::new(num.wrapping_add(1))
}

That said, it is compiled out in this instance. I am also okay with keeping it as-is.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I grabbed it from std here 🤷

Err(id) => last = id,
}
}
}
}

#[cold]
#[allow(dead_code)]
fn exhausted() -> ! {
panic!("failed to generate unique thread ID: bitspace exhausted")
}
34 changes: 8 additions & 26 deletions tokio/src/task/local.rs
@@ -1,8 +1,8 @@
//! Runs `!Send` futures on the current thread.
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::{Arc, Mutex};
use crate::loom::thread::{self, ThreadId};
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::runtime::{context, ThreadId};
use crate::sync::AtomicWaker;
use crate::util::RcCell;

Expand Down Expand Up @@ -277,12 +277,10 @@ pin_project! {
}

tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
thread_id: Cell::new(None),
ctx: RcCell::new(),
} });

struct LocalData {
thread_id: Cell<Option<ThreadId>>,
ctx: RcCell<Context>,
}

Expand Down Expand Up @@ -379,12 +377,14 @@ impl fmt::Debug for LocalEnterGuard {
impl LocalSet {
/// Returns a new local task set.
pub fn new() -> LocalSet {
let owner = context::thread_id().expect("cannot create LocalSet during thread shutdown");

LocalSet {
tick: Cell::new(0),
context: Rc::new(Context {
shared: Arc::new(Shared {
local_state: LocalState {
owner: thread_id().expect("cannot create LocalSet during thread shutdown"),
owner,
owned: LocalOwnedTasks::new(),
local_queue: UnsafeCell::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
},
Expand Down Expand Up @@ -949,7 +949,7 @@ impl Shared {

// We are on the thread that owns the `LocalSet`, so we can
// wake to the local queue.
_ if localdata.get_id() == Some(self.local_state.owner) => {
_ if context::thread_id().ok() == Some(self.local_state.owner) => {
unsafe {
// Safety: we just checked that the thread ID matches
// the localset's owner, so this is safe.
Expand Down Expand Up @@ -1093,7 +1093,9 @@ impl LocalState {
// if we couldn't get the thread ID because we're dropping the local
// data, skip the assertion --- the `Drop` impl is not going to be
// called from another thread, because `LocalSet` is `!Send`
thread_id().map(|id| id == self.owner).unwrap_or(true),
context::thread_id()
.map(|id| id == self.owner)
.unwrap_or(true),
"`LocalSet`'s local run queue must not be accessed by another thread!"
);
}
Expand All @@ -1103,26 +1105,6 @@ impl LocalState {
// ensure they are on the same thread that owns the `LocalSet`.
unsafe impl Send for LocalState {}

impl LocalData {
fn get_id(&self) -> Option<ThreadId> {
self.thread_id.get()
}

fn get_or_insert_id(&self) -> ThreadId {
self.thread_id.get().unwrap_or_else(|| {
let id = thread::current().id();
self.thread_id.set(Some(id));
id
})
}
}

fn thread_id() -> Option<ThreadId> {
CURRENT
.try_with(|localdata| localdata.get_or_insert_id())
.ok()
}

#[cfg(all(test, not(loom)))]
mod tests {
use super::*;
Expand Down
6 changes: 3 additions & 3 deletions tokio/tests/rt_metrics.rs
Expand Up @@ -141,16 +141,16 @@ fn worker_noop_count() {
time::sleep(Duration::from_millis(1)).await;
});
drop(rt);
assert!(2 <= metrics.worker_noop_count(0));
assert!(0 < metrics.worker_noop_count(0));

let rt = threaded();
let metrics = rt.metrics();
rt.block_on(async {
time::sleep(Duration::from_millis(1)).await;
});
drop(rt);
assert!(1 <= metrics.worker_noop_count(0));
assert!(1 <= metrics.worker_noop_count(1));
assert!(0 < metrics.worker_noop_count(0));
assert!(0 < metrics.worker_noop_count(1));
}

#[test]
Expand Down