Skip to content

Commit

Permalink
Reduce size of Thread
Browse files Browse the repository at this point in the history
Make insertion fully cold
  • Loading branch information
terrarier2111 committed Jul 22, 2023
1 parent ee60698 commit fd17302
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 44 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ rust-version = "1.59"

[features]
# this feature provides performance improvements using nightly features
nightly = []
nightly = ["memoffset"]

[badges]
travis-ci = { repository = "Amanieu/thread_local-rs" }
Expand All @@ -23,9 +23,10 @@ once_cell = "1.5.2"
# this is required to gate `nightly` related code paths
cfg-if = "1.0.0"
crossbeam-utils = "0.8.15"
memoffset = { version = "0.9.0", optional = true }

[dev-dependencies]
criterion = "0.4.0"
criterion = "0.4"

[[bench]]
name = "thread_local"
Expand Down
30 changes: 18 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,11 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> T,
{
unsafe {
self.get_or_try(|| Ok::<T, ()>(create()))
.unchecked_unwrap_ok()
if let Some(val) = self.get() {
return val;
}

self.insert(create)
}

/// Returns the element for the current thread, or creates it if it doesn't
Expand All @@ -201,12 +202,11 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::get();
if let Some(val) = self.get_inner(thread) {
if let Some(val) = self.get() {
return Ok(val);
}

Ok(self.insert(create()?))
self.insert_maybe(create)
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
Expand All @@ -227,14 +227,22 @@ impl<T: Send> ThreadLocal<T> {
}

#[cold]
fn insert(&self, data: T) -> &T {
fn insert_maybe<F: FnOnce() -> Result<T, E>, E>(&self, gen: F) -> Result<&T, E> {
    // Run the fallible generator first; only a successfully produced value is
    // handed on to `insert`. An error is returned to the caller unchanged and
    // nothing is stored for this thread.
    match gen() {
        Ok(value) => Ok(self.insert(move || value)),
        Err(err) => Err(err),
    }
}

#[cold]
fn insert<F: FnOnce() -> T>(&self, gen: F) -> &T {
// Call the generator here so that it also runs inside this #[cold] function.
let data = gen();
let thread = thread_id::get();
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

// If the bucket doesn't already exist, we need to allocate it
let bucket_ptr = if bucket_ptr.is_null() {
let new_bucket = allocate_bucket(thread.bucket_size);
let new_bucket = allocate_bucket(thread.bucket_size());

match bucket_atomic_ptr.compare_exchange(
ptr::null_mut(),
Expand All @@ -247,7 +255,7 @@ impl<T: Send> ThreadLocal<T> {
// another thread stored a new bucket before we could,
// and we can free our bucket and use that one instead
Err(bucket_ptr) => {
unsafe { deallocate_bucket(new_bucket, thread.bucket_size) }
unsafe { deallocate_bucket(new_bucket, thread.bucket_size()) }
bucket_ptr
}
}
Expand Down Expand Up @@ -496,9 +504,7 @@ impl<T: Send> Iterator for IntoIter<T> {
fn next(&mut self) -> Option<T> {
self.raw.next_mut(&mut self.thread_local).map(|entry| {
*entry.present.get_mut() = false;
unsafe {
std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
}
unsafe { mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init() }
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
Expand Down
90 changes: 60 additions & 30 deletions src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,47 @@ static THREAD_ID_MANAGER: Lazy<Mutex<ThreadIdManager>> =
/// A thread ID may be reused after a thread exits.
#[derive(Clone, Copy)]
pub(crate) struct Thread {
/// The thread ID obtained from the thread ID manager.
pub(crate) id: usize,
/// The bucket this thread's local storage will be in.
pub(crate) bucket: usize,
/// The size of the bucket this thread's local storage will be in.
pub(crate) bucket_size: usize,
/// The index into the bucket this thread's local storage is in.
pub(crate) index: usize,
}

impl Thread {
/// id: The thread ID obtained from the thread ID manager.
#[inline]
fn new(id: usize) -> Self {
let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
let bucket_size = 1 << bucket;
let index = id - (bucket_size - 1);
Self { bucket, index }
}

Self {
id,
bucket,
bucket_size,
index,
}
/// Number of slots in the bucket that holds this thread's local storage.
///
/// Computed on demand as `2^bucket` from the `bucket` field.
#[inline]
pub fn bucket_size(&self) -> usize {
    let shift = self.bucket;
    1_usize << shift
}
}

cfg_if::cfg_if! {
if #[cfg(feature = "nightly")] {
use memoffset::offset_of;
use std::ptr::null;
use std::cell::UnsafeCell;

// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
#[thread_local]
static mut THREAD: Option<Thread> = None;
static THREAD: UnsafeCell<ThreadWrapper> = UnsafeCell::new(ThreadWrapper {
self_ptr: null(),
thread: Thread {
index: 0,
bucket: 0,
},
});
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } }; }

// Guard to ensure the thread ID is released on thread exit.
Expand All @@ -97,17 +106,41 @@ cfg_if::cfg_if! {
// will go through get_slow which will either panic or
// initialize a new ThreadGuard.
unsafe {
THREAD = None;
(&mut *THREAD.get()).self_ptr = null();
}
THREAD_ID_MANAGER.lock().free(self.id.get());
}
}

/// Data which is unique to the current thread while it is running.
/// A thread ID may be reused after a thread exits.
///
/// This is the value stored in the `THREAD` thread-local: the current
/// thread's `Thread` data together with a pointer that doubles as the
/// "has this thread been assigned an ID yet" flag.
///
/// This wrapper exists to hide multiple accesses to the TLS data
/// from the backend, as these can lead to inefficient codegen
/// (to be precise, they can lead to multiple TLS address lookups).
#[derive(Clone, Copy)]
struct ThreadWrapper {
/// Null while the current thread has no ID assigned. `ThreadWrapper::new`
/// sets it to the address of the `thread` field inside the `THREAD` static.
self_ptr: *const Thread,
/// Bucket and index of the current thread's storage slot.
thread: Thread,
}

impl ThreadWrapper {
    /// Builds the wrapper for a thread that has just been assigned `id`.
    ///
    /// `id` is the thread ID obtained from the thread ID manager. The
    /// returned `self_ptr` is the address of the `thread` field inside the
    /// current thread's `THREAD` static, so the value only becomes meaningful
    /// once it has been written into `THREAD`.
    #[inline]
    fn new(id: usize) -> Self {
        // Offset the pointer with pointer arithmetic instead of a round trip
        // through `usize`, so the result keeps the provenance of `THREAD`.
        // `wrapping_add` is a safe call, and the offset is that of a field of
        // `ThreadWrapper`, so the result stays inside the static.
        let self_ptr = THREAD
            .get()
            .cast_const()
            .cast::<u8>()
            .wrapping_add(offset_of!(ThreadWrapper, thread))
            .cast::<Thread>();
        Self {
            self_ptr,
            thread: Thread::new(id),
        }
    }
}

/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
if let Some(thread) = unsafe { THREAD } {
thread
let thread = unsafe { *THREAD.get() };
if !thread.self_ptr.is_null() {
unsafe { thread.self_ptr.read() }
} else {
get_slow()
}
Expand All @@ -116,12 +149,13 @@ cfg_if::cfg_if! {
/// Out-of-line slow path for allocating a thread ID.
#[cold]
fn get_slow() -> Thread {
let new = Thread::new(THREAD_ID_MANAGER.lock().alloc());
let id = THREAD_ID_MANAGER.lock().alloc();
let new = ThreadWrapper::new(id);
unsafe {
THREAD = Some(new);
*THREAD.get() = new;
}
THREAD_GUARD.with(|guard| guard.id.set(new.id));
new
THREAD_GUARD.with(|guard| guard.id.set(id));
new.thread
}
} else {
// This is split into 2 thread-local variables so that we can check whether the
Expand Down Expand Up @@ -164,9 +198,10 @@ cfg_if::cfg_if! {
/// Out-of-line slow path for allocating a thread ID.
#[cold]
fn get_slow(thread: &Cell<Option<Thread>>) -> Thread {
let new = Thread::new(THREAD_ID_MANAGER.lock().alloc());
let id = THREAD_ID_MANAGER.lock().alloc();
let new = Thread::new(id);
thread.set(Some(new));
THREAD_GUARD.with(|guard| guard.id.set(new.id));
THREAD_GUARD.with(|guard| guard.id.set(id));
new
}
}
Expand All @@ -175,32 +210,27 @@ cfg_if::cfg_if! {
#[test]
fn test_thread() {
let thread = Thread::new(0);
assert_eq!(thread.id, 0);
assert_eq!(thread.bucket, 0);
assert_eq!(thread.bucket_size, 1);
assert_eq!(thread.bucket_size(), 1);
assert_eq!(thread.index, 0);

let thread = Thread::new(1);
assert_eq!(thread.id, 1);
assert_eq!(thread.bucket, 1);
assert_eq!(thread.bucket_size, 2);
assert_eq!(thread.bucket_size(), 2);
assert_eq!(thread.index, 0);

let thread = Thread::new(2);
assert_eq!(thread.id, 2);
assert_eq!(thread.bucket, 1);
assert_eq!(thread.bucket_size, 2);
assert_eq!(thread.bucket_size(), 2);
assert_eq!(thread.index, 1);

let thread = Thread::new(3);
assert_eq!(thread.id, 3);
assert_eq!(thread.bucket, 2);
assert_eq!(thread.bucket_size, 4);
assert_eq!(thread.bucket_size(), 4);
assert_eq!(thread.index, 0);

let thread = Thread::new(19);
assert_eq!(thread.id, 19);
assert_eq!(thread.bucket, 4);
assert_eq!(thread.bucket_size, 16);
assert_eq!(thread.bucket_size(), 16);
assert_eq!(thread.index, 4);
}

0 comments on commit fd17302

Please sign in to comment.