diff --git a/src/lib.rs b/src/lib.rs
index a92f272..16b39bc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -143,7 +143,7 @@ impl<T: Send> Drop for ThreadLocal<T> {
             let this_bucket_size = 1 << i;
 
             if bucket_ptr.is_null() {
-                break;
+                continue;
             }
 
             unsafe { deallocate_bucket(bucket_ptr, this_bucket_size) };
@@ -205,7 +205,7 @@ impl<T: Send> ThreadLocal<T> {
             return Ok(val);
         }
 
-        Ok(self.insert(create()?))
+        Ok(self.insert(thread, create()?))
     }
 
     fn get_inner(&self, thread: Thread) -> Option<&T> {
@@ -226,8 +226,7 @@ impl<T: Send> ThreadLocal<T> {
     }
 
     #[cold]
-    fn insert(&self, data: T) -> &T {
-        let thread = thread_id::get();
+    fn insert(&self, thread: Thread, data: T) -> &T {
         let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
         let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
 
@@ -372,16 +371,14 @@ impl RawIter {
             let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
             let bucket = bucket.load(Ordering::Acquire);
 
-            if bucket.is_null() {
-                return None;
-            }
-
-            while self.index < self.bucket_size {
-                let entry = unsafe { &*bucket.add(self.index) };
-                self.index += 1;
-                if entry.present.load(Ordering::Acquire) {
-                    self.yielded += 1;
-                    return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
+            if !bucket.is_null() {
+                while self.index < self.bucket_size {
+                    let entry = unsafe { &*bucket.add(self.index) };
+                    self.index += 1;
+                    if entry.present.load(Ordering::Acquire) {
+                        self.yielded += 1;
+                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
+                    }
                 }
             }
 
@@ -401,16 +398,14 @@ impl RawIter {
             let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
             let bucket = *bucket.get_mut();
 
-            if bucket.is_null() {
-                return None;
-            }
-
-            while self.index < self.bucket_size {
-                let entry = unsafe { &mut *bucket.add(self.index) };
-                self.index += 1;
-                if *entry.present.get_mut() {
-                    self.yielded += 1;
-                    return Some(entry);
+            if !bucket.is_null() {
+                while self.index < self.bucket_size {
+                    let entry = unsafe { &mut *bucket.add(self.index) };
+                    self.index += 1;
+                    if *entry.present.get_mut() {
+                        self.yielded += 1;
+                        return Some(entry);
+                    }
                 }
             }
 
@@ -525,7 +520,8 @@ unsafe fn deallocate_bucket<T>(bucket: *mut Entry<T>, size: usize) {
 
 #[cfg(test)]
 mod tests {
-    use super::ThreadLocal;
+    use super::*;
+    use std::cell::RefCell;
     use std::sync::atomic::AtomicUsize;
     use std::sync::atomic::Ordering::Relaxed;
 
@@ -627,6 +623,32 @@
         assert_eq!(dropped.load(Relaxed), 1);
     }
 
+    #[test]
+    fn test_earlyreturn_buckets() {
+        struct Dropped(Arc<AtomicUsize>);
+        impl Drop for Dropped {
+            fn drop(&mut self) {
+                self.0.fetch_add(1, Relaxed);
+            }
+        }
+        let dropped = Arc::new(AtomicUsize::new(0));
+
+        // We use a high `id` here to guarantee that a lazily allocated bucket somewhere in the middle is used.
+        // Neither iteration nor `Drop` must early-return on the `null` buckets that come before it.
+        let thread = Thread::new(1234);
+        assert!(thread.bucket > 1);
+
+        let mut local = ThreadLocal::new();
+        local.insert(thread, Dropped(dropped.clone()));
+
+        let item = local.iter().next().unwrap();
+        assert_eq!(item.0.load(Relaxed), 0);
+        let item = local.iter_mut().next().unwrap();
+        assert_eq!(item.0.load(Relaxed), 0);
+        drop(local);
+        assert_eq!(dropped.load(Relaxed), 1);
+    }
+
     #[test]
     fn is_sync() {
         fn foo<T: Sync>() {}
diff --git a/src/thread_id.rs b/src/thread_id.rs
index 075ab09..3024e9c 100644
--- a/src/thread_id.rs
+++ b/src/thread_id.rs
@@ -59,7 +59,7 @@ pub(crate) struct Thread {
     pub(crate) index: usize,
 }
 impl Thread {
-    fn new(id: usize) -> Self {
+    pub(crate) fn new(id: usize) -> Self {
         let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
         let bucket_size = 1 << bucket;
         let index = id - (bucket_size - 1);
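
Note on the magic number 1234 in the test: bucket `b` holds `2^b` slots, so a thread id maps to bucket `floor(log2(id + 1))`. Id 1234 therefore lands in bucket 10 while buckets 0 through 9 stay `null` until lower-id threads insert, which is exactly the sparse layout the patched `continue` and `if !bucket.is_null()` paths must walk past. Below is a small standalone sketch of that mapping; the helper name `bucket_and_index` and its local pointer-width value are illustrative only, but the arithmetic mirrors `Thread::new` in src/thread_id.rs.

// Standalone sketch (not part of the patch): how a thread id maps to a
// (bucket, index) pair under the doubling-bucket scheme used above.
fn bucket_and_index(id: usize) -> (usize, usize) {
    let pointer_width = usize::BITS as usize; // 64 on 64-bit targets
    let bucket = pointer_width - ((id + 1).leading_zeros() as usize) - 1;
    let bucket_size = 1usize << bucket;
    let index = id - (bucket_size - 1);
    (bucket, index)
}

fn main() {
    assert_eq!(bucket_and_index(0), (0, 0)); // bucket 0 holds id 0
    assert_eq!(bucket_and_index(1), (1, 0)); // bucket 1 holds ids 1..=2
    assert_eq!(bucket_and_index(2), (1, 1));
    // Id 1234 lands in bucket 10 (ids 1023..=2046), leaving buckets
    // 0..=9 null: the layout on which the old `break` in `Drop` and the
    // old `return None` in iteration gave up too early.
    assert_eq!(bucket_and_index(1234), (10, 211));
}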