Skip to content

Commit

Permalink
Fix thread::scope result drop order bug
Browse files Browse the repository at this point in the history
Fix a bug with thread::scope where the result of a spawned thread was being
saved even when the scoped thread's handle had been dropped. This was
causing the result to be dropped after the spawned thread was "finished",
which was after the point at which the call to thread::scope had already
completed. Add a regression test for this behavior.
  • Loading branch information
akonradi committed Aug 19, 2023
1 parent 2762cc0 commit ae21d3d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 12 deletions.
34 changes: 22 additions & 12 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl Builder {
f,
self.name,
self.stack_size,
Some(scope.data.clone()),
Some(&scope.data),
location!(),
)
},
Expand Down Expand Up @@ -309,7 +309,7 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
/// See [`scope`] for more details.
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope> {
data: Arc<ScopeData>,
data: ScopeData,
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}
Expand All @@ -329,10 +329,10 @@ where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
let scope = Scope {
data: Arc::new(ScopeData {
data: ScopeData {
running_threads: Mutex::default(),
main_thread: current(),
}),
},
env: PhantomData,
scope: PhantomData,
};
Expand Down Expand Up @@ -394,7 +394,6 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
#[derive(Debug)]
struct JoinHandleInner<'scope, T> {
data: Arc<ThreadData<'scope, T>>,
notify: rt::Notify,
thread: Thread,
}

Expand Down Expand Up @@ -423,7 +422,7 @@ unsafe fn spawn_internal<'scope, F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
scope: Option<Arc<ScopeData>>,
scope: Option<&'scope ScopeData>,
location: Location,
) -> JoinHandleInner<'scope, T>
where
Expand All @@ -435,18 +434,28 @@ where
.clone()
.map(|scope| (scope.add_running_thread(), scope));
let thread_data = Arc::new(ThreadData::new());
let notify = rt::Notify::new(true, false);

let id = {
let name = name.clone();
let thread_data = thread_data.clone();
// Hold a weak reference so that if the thread handle gets dropped, we
// don't try to store the result or notify anybody unnecessarily.
let weak_data = Arc::downgrade(&thread_data);

let body: Box<dyn FnOnce() + 'scope> = Box::new(move || {
rt::execution(|execution| {
init_current(execution, name);
});

*thread_data.result.lock().unwrap() = Some(Ok(f()));
notify.notify(location);
// Ensure everything from the spawned thread's execution either gets
// stored in the thread handle or dropped before notifying that the
// thread has completed.
{
let result = f();
if let Some(thread_data) = weak_data.upgrade() {
*thread_data.result.lock().unwrap() = Some(Ok(result));
thread_data.notification.notify(location);
}
}

if let Some((notifier, scope)) = scope_notify {
notifier.notify(location!());
Expand All @@ -461,7 +470,6 @@ where

JoinHandleInner {
data: thread_data,
notify,
thread: Thread {
id: ThreadId { id },
name,
Expand All @@ -473,21 +481,23 @@ where
#[derive(Debug)]
struct ThreadData<'scope, T> {
result: Mutex<Option<std::thread::Result<T>>>,
notification: rt::Notify,
_marker: PhantomData<Option<&'scope ScopeData>>,
}

impl<'scope, T> ThreadData<'scope, T> {
fn new() -> Self {
Self {
result: Mutex::new(None),
notification: rt::Notify::new(true, false),
_marker: PhantomData,
}
}
}

impl<'scope, T> JoinHandleInner<'scope, T> {
fn join(self) -> std::thread::Result<T> {
self.notify.wait(location!());
self.data.notification.wait(location!());
self.data.result.lock().unwrap().take().unwrap()
}

Expand Down
42 changes: 42 additions & 0 deletions tests/thread_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,45 @@ fn scoped_and_unscoped_threads() {
assert_eq!(v, 2);
})
}

struct YieldAndIncrementOnDrop<'a>(&'a std::sync::atomic::AtomicUsize);

impl Drop for YieldAndIncrementOnDrop<'_> {
fn drop(&mut self) {
thread::yield_now();
self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}

#[test]
fn scoped_thread_wait_until_finished() {
loom::model(|| {
let a = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let r: &std::sync::atomic::AtomicUsize = &a;
thread::scope(|s| {
s.spawn(move || {
r.fetch_add(2, std::sync::atomic::Ordering::SeqCst);
YieldAndIncrementOnDrop(r)
});
});
assert_eq!(a.load(std::sync::atomic::Ordering::SeqCst), 3);
});
}

#[test]
fn scoped_thread_join_handle_forgotten() {
loom::model(|| {
let a = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let r: &std::sync::atomic::AtomicUsize = &a;
thread::scope(|s| {
let handle = s.spawn(move || {
r.fetch_add(2, std::sync::atomic::Ordering::SeqCst);
YieldAndIncrementOnDrop(r)
});
std::mem::forget(handle)
});
// Expect only 2 since the spawned thread will complete but its result
// will be leaked and so never dropped.
assert_eq!(a.load(std::sync::atomic::Ordering::SeqCst), 2);
});
}

0 comments on commit ae21d3d

Please sign in to comment.