Use a list instead of a hash map #22

Merged: 9 commits, Jan 7, 2021
Changes from 1 commit
14 changes: 10 additions & 4 deletions src/cached.rs
@@ -3,6 +3,7 @@ use std::cell::UnsafeCell;
 use std::fmt;
 use std::panic::UnwindSafe;
 use std::sync::atomic::{AtomicUsize, Ordering};
+use std::usize;
 use thread_id;
 use unreachable::{UncheckedOptionExt, UncheckedResultExt};

@@ -30,7 +31,7 @@ impl<T: Send> CachedThreadLocal<T> {
     /// Creates a new empty `CachedThreadLocal`.
     pub fn new() -> CachedThreadLocal<T> {
         CachedThreadLocal {
-            owner: AtomicUsize::new(0),
+            owner: AtomicUsize::new(usize::MAX),
             local: UnsafeCell::new(None),
             global: ThreadLocal::new(),
         }
@@ -43,7 +44,7 @@ impl<T: Send> CachedThreadLocal<T> {
         if owner == id {
             return unsafe { Some((*self.local.get()).as_ref().unchecked_unwrap()) };
         }
-        if owner == 0 {
+        if owner == usize::MAX {
             return None;
         }
         self.global.get_fast(id)
@@ -83,15 +84,20 @@ impl<T: Send> CachedThreadLocal<T> {
     where
         F: FnOnce() -> Result<T, E>,
     {
-        if owner == 0 && self.owner.compare_and_swap(0, id, Ordering::Relaxed) == 0 {
+        if owner == usize::MAX
+            && self
+                .owner
+                .compare_and_swap(usize::MAX, id, Ordering::Relaxed)
+                == usize::MAX
+        {
             unsafe {
                 (*self.local.get()) = Some(Box::new(create()?));
                 return Ok((*self.local.get()).as_ref().unchecked_unwrap());
             }
         }
         match self.global.get_fast(id) {
             Some(x) => Ok(x),
-            None => Ok(self.global.insert(id, Box::new(create()?), true)),
+            None => Ok(self.global.insert(id, create()?, true)),
         }
     }

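Why `usize::MAX` replaces `0` as the sentinel: thread ids now double as list indices, so 0 is a valid id and can no longer mean "unowned". A minimal sketch of the claim protocol, assuming a hypothetical `Owner` wrapper and using `compare_exchange` rather than the since-deprecated `compare_and_swap` seen above:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Sentinel meaning "no thread owns the cached slot yet". Thread id 0 is a
// valid list index, so 0 cannot serve as the sentinel anymore.
const UNOWNED: usize = usize::MAX;

struct Owner(AtomicUsize);

impl Owner {
    fn new() -> Self {
        Owner(AtomicUsize::new(UNOWNED))
    }

    // Atomically claim the slot for thread `id`; returns true on success.
    fn try_claim(&self, id: usize) -> bool {
        self.0
            .compare_exchange(UNOWNED, id, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
    }
}

fn main() {
    let owner = Owner::new();
    assert!(owner.try_claim(0)); // id 0 is claimable, which the old sentinel forbade
    assert!(!owner.try_claim(1)); // slot already owned
}
```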
238 changes: 100 additions & 138 deletions src/lib.rs
@@ -84,41 +84,42 @@ use std::cell::UnsafeCell;
 use std::fmt;
 use std::marker::PhantomData;
 use std::panic::UnwindSafe;
-use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicPtr, Ordering};
 use std::sync::Mutex;
 use unreachable::{UncheckedOptionExt, UncheckedResultExt};
 
 /// Thread-local variable wrapper
 ///
 /// See the [module-level documentation](index.html) for more.
 pub struct ThreadLocal<T: Send> {
-    // Pointer to the current top-level hash table
-    table: AtomicPtr<Table<T>>,
+    // Pointer to the current top-level list
+    list: AtomicPtr<List<T>>,
 
     // Lock used to guard against concurrent modifications. This is only taken
-    // while writing to the table, not when reading from it. This also guards
-    // the counter for the total number of values in the hash table.
+    // while writing to the list, not when reading from it. This also guards
+    // the counter for the total number of values in the thread local.
     lock: Mutex<usize>,
 }
 
-struct Table<T: Send> {
-    // Hash entries for the table
-    entries: Box<[TableEntry<T>]>,
-
-    // Number of bits used for the hash function
-    hash_bits: usize,
-
-    // Previous table, half the size of the current one
-    prev: Option<Box<Table<T>>>,
+/// A list of thread-local values.
+struct List<T: Send> {
+    // The thread local values in this list. If any values is `None`, it is
+    // either in an earlier list or it is uninitialized.
+    values: Box<[UnsafeCell<Option<T>>]>,
+
+    // Previous list, half the size of the current one
+    //
+    // This cannot be a Box as that would result in the Box's pointer
+    // potentially being aliased when creating a new list, which is UB.
+    prev: Option<*mut List<T>>,
 }
 
-struct TableEntry<T: Send> {
-    // Current owner of this entry, or 0 if this is an empty entry
-    owner: AtomicUsize,
-
-    // The object associated with this entry. This is only ever accessed by the
-    // owner of the entry.
-    data: UnsafeCell<Option<Box<T>>>,
+impl<T: Send> Drop for List<T> {
+    fn drop(&mut self) {
+        if let Some(prev) = self.prev.take() {
+            drop(unsafe { Box::from_raw(prev) });
+        }
+    }
 }
 
 // ThreadLocal is always Sync, even if T isn't
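Since `prev` is a raw pointer, `List` frees its predecessor by reassembling the `Box` in `drop`; dropping the newest list then releases the whole chain, one level at a time. A self-contained sketch of the same ownership pattern, using a hypothetical `Node` type:

```rust
struct Node {
    prev: Option<*mut Node>,
}

impl Drop for Node {
    fn drop(&mut self) {
        if let Some(prev) = self.prev.take() {
            // Reconstitute the Box so the older node is freed; its own Drop
            // impl then frees the rest of the chain recursively.
            drop(unsafe { Box::from_raw(prev) });
        }
    }
}

fn main() {
    let old = Box::into_raw(Box::new(Node { prev: None }));
    let top = Node { prev: Some(old) };
    drop(top); // frees `top`, then `old`; nothing leaks
}
```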
@@ -133,47 +134,29 @@ impl<T: Send> Default for ThreadLocal<T> {
 impl<T: Send> Drop for ThreadLocal<T> {
     fn drop(&mut self) {
         unsafe {
-            Box::from_raw(self.table.load(Ordering::Relaxed));
-        }
-    }
-}
-
-// Implementation of Clone for TableEntry, needed to make vec![] work
-impl<T: Send> Clone for TableEntry<T> {
-    fn clone(&self) -> TableEntry<T> {
-        TableEntry {
-            owner: AtomicUsize::new(0),
-            data: UnsafeCell::new(None),
+            Box::from_raw(*self.list.get_mut());
         }
     }
 }
 
-// Hash function for the thread id
-#[cfg(target_pointer_width = "32")]
-#[inline]
-fn hash(id: usize, bits: usize) -> usize {
-    id.wrapping_mul(0x9E3779B9) >> (32 - bits)
-}
-#[cfg(target_pointer_width = "64")]
-#[inline]
-fn hash(id: usize, bits: usize) -> usize {
-    id.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> (64 - bits)
-}
-
 impl<T: Send> ThreadLocal<T> {
     /// Creates a new empty `ThreadLocal`.
     pub fn new() -> ThreadLocal<T> {
-        let entry = TableEntry {
-            owner: AtomicUsize::new(0),
-            data: UnsafeCell::new(None),
-        };
-        let table = Table {
-            entries: vec![entry; 2].into_boxed_slice(),
-            hash_bits: 1,
+        ThreadLocal::with_capacity(2)
+    }
+
+    /// Creates a new `ThreadLocal` with an initial capacity. If less than the capacity threads
+    /// access the thread local it will never reallocate.
+    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
+        let list = List {
+            values: (0..capacity)
+                .map(|_| UnsafeCell::new(None))
+                .collect::<Vec<_>>()
+                .into_boxed_slice(),
             prev: None,
         };
         ThreadLocal {
-            table: AtomicPtr::new(Box::into_raw(Box::new(table))),
+            list: AtomicPtr::new(Box::into_raw(Box::new(list))),
             lock: Mutex::new(0),
         }
     }
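The capacity is now a thread count rather than a hash-table size, so pre-sizing is exact: as long as at most `capacity` threads touch the value, the initial list is never outgrown. A hypothetical usage sketch, assuming the crate's existing `get_or` API, which this PR leaves unchanged:

```rust
use thread_local::ThreadLocal;

fn main() {
    // Room for thread ids 0..16 up front; with at most 16 threads, no
    // larger list is ever allocated.
    let tls: ThreadLocal<u32> = ThreadLocal::with_capacity(16);
    let value = tls.get_or(|| 42);
    assert_eq!(*value, 42);
}
```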
Expand Down Expand Up @@ -206,114 +189,92 @@ impl<T: Send> ThreadLocal<T> {
let id = thread_id::get();
match self.get_fast(id) {
Some(x) => Ok(x),
None => Ok(self.insert(id, Box::new(create()?), true)),
None => Ok(self.insert(id, create()?, true)),
}
}

// Simple hash table lookup function
fn lookup(id: usize, table: &Table<T>) -> Option<&UnsafeCell<Option<Box<T>>>> {
// Because we use a Mutex to prevent concurrent modifications (but not
// reads) of the hash table, we can avoid any memory barriers here. No
// elements between our hash bucket and our value can have been modified
// since we inserted our thread-local value into the table.
for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) {
let owner = entry.owner.load(Ordering::Relaxed);
if owner == id {
return Some(&entry.data);
}
if owner == 0 {
return None;
}
}
unreachable!();
}

// Fast path: try to find our thread in the top-level hash table
// Fast path: try to find our thread in the top-level list
fn get_fast(&self, id: usize) -> Option<&T> {
let table = unsafe { &*self.table.load(Ordering::Acquire) };
match Self::lookup(id, table) {
Some(x) => unsafe { Some((*x.get()).as_ref().unchecked_unwrap()) },
None => self.get_slow(id, table),
}
let list = unsafe { &*self.list.load(Ordering::Acquire) };
list.values
.get(id)
.and_then(|cell| unsafe { &*cell.get() }.as_ref())
.or_else(|| self.get_slow(id, list))
}

// Slow path: try to find our thread in the other hash tables, and then
// move it to the top-level hash table.
// Slow path: try to find our thread in the other lists, and then move it to
// the top-level list.
#[cold]
fn get_slow(&self, id: usize, table_top: &Table<T>) -> Option<&T> {
let mut current = &table_top.prev;
while let Some(ref table) = *current {
if let Some(x) = Self::lookup(id, table) {
let data = unsafe { (*x.get()).take().unchecked_unwrap() };
return Some(self.insert(id, data, false));
fn get_slow(&self, id: usize, list_top: &List<T>) -> Option<&T> {
let mut current = list_top.prev;
while let Some(list) = current {
let list = unsafe { &*list };

match list.values.get(id) {
Some(value) => {
let value_option = unsafe { &mut *value.get() };
if value_option.is_some() {
let value = unsafe { value_option.take().unchecked_unwrap() };
return Some(self.insert(id, value, false));
}
}
None => break,
}
current = &table.prev;
current = list.prev;
}
None
}
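Note that `get_fast` loads the top-level pointer with `Acquire`, pairing with the `Release` store in `insert` below; that pairing is what lets readers bypass the `Mutex` entirely. A reduced sketch of the publication pattern, simplified to a single pointer with illustrative `publish`/`read` helpers:

```rust
use std::sync::atomic::{AtomicPtr, Ordering};

// Writer: fully initialize the value, then publish it with Release so that
// the initialization happens-before any reader that observes the pointer.
fn publish<T>(slot: &AtomicPtr<T>, value: T) {
    let ptr = Box::into_raw(Box::new(value));
    slot.store(ptr, Ordering::Release);
}

// Reader: Acquire pairs with the Release store above, making it sound to
// dereference the observed pointer without taking any lock.
fn read<T>(slot: &AtomicPtr<T>) -> Option<&T> {
    let ptr = slot.load(Ordering::Acquire);
    unsafe { ptr.as_ref() }
}
```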

     #[cold]
-    fn insert(&self, id: usize, data: Box<T>, new: bool) -> &T {
-        // Lock the Mutex to ensure only a single thread is modify the hash
-        // table at once.
+    fn insert(&self, id: usize, data: T, new: bool) -> &T {
+        let list_raw = self.list.load(Ordering::Relaxed);
+        let list = unsafe { &*list_raw };
+
+        // Lock the Mutex to ensure only a single thread is adding new lists at
+        // once
         let mut count = self.lock.lock().unwrap();
         if new {
             *count += 1;
         }
-        let table_raw = self.table.load(Ordering::Relaxed);
-        let table = unsafe { &*table_raw };
-
-        // If the current top-level hash table is more than 75% full, add a new
-        // level with 2x the capacity. Elements will be moved up to the new top
-        // level table as they are accessed.
-        let table = if *count > table.entries.len() * 3 / 4 {
-            let entry = TableEntry {
-                owner: AtomicUsize::new(0),
-                data: UnsafeCell::new(None),
-            };
-            let new_table = Box::into_raw(Box::new(Table {
-                entries: vec![entry; table.entries.len() * 2].into_boxed_slice(),
-                hash_bits: table.hash_bits + 1,
-                prev: unsafe { Some(Box::from_raw(table_raw)) },
+
+        // If there isn't space for this thread's local, add a new list.
+        let list = if id >= list.values.len() {
+            let new_list = Box::into_raw(Box::new(List {
+                values: (0..std::cmp::max(list.values.len() * 2, id + 1))
+                    // Values will be lazily moved into the top-level list, so
+                    // it starts out empty
+                    .map(|_| UnsafeCell::new(None))
+                    .collect::<Vec<_>>()
+                    .into_boxed_slice(),
+                prev: Some(list_raw),
             }));
-            self.table.store(new_table, Ordering::Release);
-            unsafe { &*new_table }
+            self.list.store(new_list, Ordering::Release);
+            unsafe { &*new_list }
         } else {
-            table
+            list
         };
 
-        // Insert the new element into the top-level hash table
-        for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) {
-            let owner = entry.owner.load(Ordering::Relaxed);
-            if owner == 0 {
-                unsafe {
-                    entry.owner.store(id, Ordering::Relaxed);
-                    *entry.data.get() = Some(data);
-                    return (*entry.data.get()).as_ref().unchecked_unwrap();
-                }
-            }
-            if owner == id {
-                // This can happen if create() inserted a value into this
-                // ThreadLocal between our calls to get_fast() and insert(). We
-                // just return the existing value and drop the newly-allocated
-                // Box.
-                unsafe {
-                    return (*entry.data.get()).as_ref().unchecked_unwrap();
-                }
-            }
+        // We are no longer adding new lists, so we don't need the guard
+        drop(count);
+
+        // Insert the new element into the top-level list
+        unsafe {
+            let value_ptr = list.values.get_unchecked(id).get();
+            *value_ptr = Some(data);
+            (&*value_ptr).as_ref().unchecked_unwrap()
         }
-        unreachable!();
     }
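Taken together, the structure is now a chain of arrays indexed directly by thread id, each at least double the size of its predecessor, with values migrated to the newest array as they are accessed. A safe, single-threaded model of the strategy, using a hypothetical `GrowableSlots` type in place of the real code's atomics and raw pointers:

```rust
// Safe stand-in for the list-of-lists layout: `Vec` and `Box` replace the
// real code's boxed slices, raw pointers, and atomics.
struct GrowableSlots<T> {
    values: Vec<Option<T>>,
    prev: Option<Box<GrowableSlots<T>>>,
}

impl<T> GrowableSlots<T> {
    fn with_capacity(capacity: usize) -> Self {
        GrowableSlots {
            values: (0..capacity).map(|_| None).collect(),
            prev: None,
        }
    }

    // Mirrors `insert`: if `id` doesn't fit, push a new top-level array of
    // at least double the size, keeping the old one reachable via `prev`.
    fn insert(&mut self, id: usize, value: T) {
        if id >= self.values.len() {
            let new_len = std::cmp::max(self.values.len() * 2, id + 1);
            let old = std::mem::replace(self, GrowableSlots::with_capacity(new_len));
            self.prev = Some(Box::new(old));
        }
        self.values[id] = Some(value);
    }

    // Mirrors `get_fast` + `get_slow`: check the newest array first, then
    // walk the `prev` chain (the real code also migrates a hit to the top).
    fn get(&self, id: usize) -> Option<&T> {
        let mut current = Some(self);
        while let Some(list) = current {
            if let Some(Some(value)) = list.values.get(id) {
                return Some(value);
            }
            current = list.prev.as_deref();
        }
        None
    }
}

fn main() {
    let mut slots = GrowableSlots::with_capacity(2);
    slots.insert(0, "a");
    slots.insert(5, "b"); // grows to max(2 * 2, 5 + 1) = 6 slots
    assert_eq!(slots.get(0), Some(&"a")); // found by walking the `prev` chain
    assert_eq!(slots.get(5), Some(&"b"));
}
```

The `max(len * 2, id + 1)` growth rule mirrors the `insert` above: doubling keeps the chain short (logarithmic in the highest thread id), while the `id + 1` term handles an id that jumps past even the doubled size.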

     fn raw_iter(&mut self) -> RawIter<T> {
         RawIter {
             remaining: *self.lock.get_mut().unwrap(),
             index: 0,
-            table: self.table.load(Ordering::Relaxed),
+            list: *self.list.get_mut(),
         }
     }
 
-    /// Returns a mutable iterator over the local values of all threads.
+    /// Returns a mutable iterator over the local values of all threads in
+    /// unspecified order.
     ///
     /// Since this call borrows the `ThreadLocal` mutably, this operation can
     /// be done safely---the mutable borrow statically guarantees no other
@@ -376,29 +337,30 @@ impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
 struct RawIter<T: Send> {
     remaining: usize,
     index: usize,
-    table: *const Table<T>,
+    list: *const List<T>,
 }
 
 impl<T: Send> Iterator for RawIter<T> {
-    type Item = *mut Option<Box<T>>;
+    type Item = *mut Option<T>;
 
-    fn next(&mut self) -> Option<*mut Option<Box<T>>> {
+    fn next(&mut self) -> Option<Self::Item> {
         if self.remaining == 0 {
             return None;
         }
 
         loop {
-            let entries = unsafe { &(*self.table).entries[..] };
-            while self.index < entries.len() {
-                let val = entries[self.index].data.get();
+            let values = &*unsafe { &*self.list }.values;
+
+            while self.index < values.len() {
+                let val = values[self.index].get();
                 self.index += 1;
                 if unsafe { (*val).is_some() } {
                     self.remaining -= 1;
                     return Some(val);
                 }
             }
             self.index = 0;
-            self.table = unsafe { &**(*self.table).prev.as_ref().unchecked_unwrap() };
+            self.list = unsafe { &**(*self.list).prev.as_ref().unchecked_unwrap() };
         }
     }

@@ -419,7 +381,7 @@ impl<'a, T: Send + 'a> Iterator for IterMut<'a, T> {
     fn next(&mut self) -> Option<&'a mut T> {
         self.raw
             .next()
-            .map(|x| unsafe { &mut **(*x).as_mut().unchecked_unwrap() })
+            .map(|x| unsafe { &mut *(*x).as_mut().unchecked_unwrap() })
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
@@ -441,7 +403,7 @@ impl<T: Send> Iterator for IntoIter<T> {
     fn next(&mut self) -> Option<T> {
         self.raw
             .next()
-            .map(|x| unsafe { *(*x).take().unchecked_unwrap() })
+            .map(|x| unsafe { (*x).take().unchecked_unwrap() })
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
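With values stored inline as `Option<T>` instead of `Option<Box<T>>`, the iterators yield with one less indirection and `IntoIter` no longer has to unbox. A hypothetical end-to-end use of the iteration APIs, assuming the crate's public `iter_mut` and `IntoIterator` impls, which this commit keeps:

```rust
use thread_local::ThreadLocal;

fn main() {
    let mut tls: ThreadLocal<i32> = ThreadLocal::with_capacity(4);
    tls.get_or(|| 1);

    // `iter_mut` visits every thread's value, in unspecified order.
    for value in tls.iter_mut() {
        *value += 1;
    }

    // `into_iter` consumes the ThreadLocal and yields values by move.
    let total: i32 = tls.into_iter().sum();
    assert_eq!(total, 2);
}
```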