Skip to content

Commit

Permalink
Add support for scoped threads
Browse files Browse the repository at this point in the history
Add loom::thread::scope to mirror std::thread::scope provided by the
standard library.
  • Loading branch information
akonradi committed Jun 12, 2023
1 parent ce8a232 commit c181071
Show file tree
Hide file tree
Showing 2 changed files with 358 additions and 47 deletions.
307 changes: 260 additions & 47 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ use std::{fmt, io};
use tracing::trace;

/// Mock implementation of `std::thread::JoinHandle`.
pub struct JoinHandle<T> {
result: Arc<Mutex<Option<std::thread::Result<T>>>>,
notify: rt::Notify,
thread: Thread,
}
pub struct JoinHandle<T>(JoinHandleInner<'static, T>);

/// Mock implementation of `std::thread::Thread`.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -129,7 +125,7 @@ where
F: 'static,
T: 'static,
{
spawn_internal(f, None, None, location!())
JoinHandle(spawn_internal_static(f, None, None, location!()))
}

/// Mock implementation of `std::thread::park`.
Expand All @@ -143,43 +139,6 @@ pub fn park() {
rt::park(location!());
}

fn spawn_internal<F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
location: Location,
) -> JoinHandle<T>
where
F: FnOnce() -> T,
F: 'static,
T: 'static,
{
let result = Arc::new(Mutex::new(None));
let notify = rt::Notify::new(true, false);

let id = {
let name = name.clone();
let result = result.clone();
rt::spawn(stack_size, move || {
rt::execution(|execution| {
init_current(execution, name);
});

*result.lock().unwrap() = Some(Ok(f()));
notify.notify(location);
})
};

JoinHandle {
result,
notify,
thread: Thread {
id: ThreadId { id },
name,
},
}
}

impl Builder {
/// Generates the base configuration for spawning a thread, from which
/// configuration methods can be chained.
Expand Down Expand Up @@ -217,21 +176,53 @@ impl Builder {
F: Send + 'static,
T: Send + 'static,
{
Ok(spawn_internal(f, self.name, self.stack_size, location!()))
Ok(JoinHandle(spawn_internal_static(
f,
self.name,
self.stack_size,
location!(),
)))
}
}

impl Builder {
/// Spawns a new scoped thread using the settings set through this `Builder`.
pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
// Safety: the call to this function requires a `&'scope Scope`
// which can only be constructed by `scope()`, which ensures that
// all spawned threads are joined before the `Scope` is destroyed.
unsafe {
spawn_internal(
f,
self.name,
self.stack_size,
Some(scope.data.clone()),
location!(),
)
},
))
}
}

impl<T> JoinHandle<T> {
/// Waits for the associated thread to finish.
#[track_caller]
pub fn join(self) -> std::thread::Result<T> {
self.notify.wait(location!());
self.result.lock().unwrap().take().unwrap()
self.0.join()
}

/// Gets a handle to the underlying [`Thread`]
pub fn thread(&self) -> &Thread {
&self.thread
self.0.thread()
}
}

Expand Down Expand Up @@ -312,3 +303,225 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
f.pad("LocalKey { .. }")
}
}

/// A scope for spawning scoped threads.
///
/// See [`scope`] for more details.
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope> {
data: Arc<ScopeData>,
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}

/// An owned permission to join on a scoped thread (block on its termination).
///
/// See [`Scope::spawn`] for details.
#[derive(Debug)]
pub struct ScopedJoinHandle<'scope, T>(JoinHandleInner<'scope, T>);

/// Create a scope for spawning scoped threads.
///
/// Mock implementation of [`std::thread::scope`].
#[track_caller]
pub fn scope<'env, F, T>(f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
let scope = Scope {
data: Arc::new(ScopeData {
running_threads: Mutex::default(),
main_thread: current(),
}),
env: PhantomData,
scope: PhantomData,
};

// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&scope)));

// Wait until all the threads are finished. This is required to fulfill
// the safety requirements of `spawn_internal`.
let running = loop {
{
let running = scope.data.running_threads.lock().unwrap();
if running.count == 0 {
break running;
}
}
park();
};

for notify in &running.notify_on_finished {
notify.wait(location!())
}

// Throw any panic from `f`, or the return value of `f` if no thread panicked.
match result {
Err(e) => std::panic::resume_unwind(e),
Ok(result) => result,
}
}

impl<'scope, 'env> Scope<'scope, 'env> {
/// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
///
/// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Builder::new()
.spawn_scoped(self, f)
.expect("failed to spawn thread")
}
}

impl<'scope, T> ScopedJoinHandle<'scope, T> {
/// Extracts a handle to the underlying thread.
pub fn thread(&self) -> &Thread {
self.0.thread()
}

/// Waits for the associated thread to finish.
pub fn join(self) -> std::thread::Result<T> {
self.0.join()
}
}

/// Handle for joining on a thread with a scope.
#[derive(Debug)]
struct JoinHandleInner<'scope, T> {
data: Arc<ThreadData<'scope, T>>,
notify: rt::Notify,
thread: Thread,
}

/// Spawns a thread without a local scope.
fn spawn_internal_static<F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
location: Location,
) -> JoinHandleInner<'static, T>
where
F: FnOnce() -> T,
F: 'static,
T: 'static,
{
// Safety: the requirements of `spawn_internal` are trivially satisfied
// since there is no `scope`.
unsafe { spawn_internal(f, name, stack_size, None, location) }
}

/// Spawns a thread with an optional scope.
///
/// The caller must ensure that if `scope` is not None, the provided closure
/// finishes before `'scope` ends.
unsafe fn spawn_internal<'scope, F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
scope: Option<Arc<ScopeData>>,
location: Location,
) -> JoinHandleInner<'scope, T>
where
F: FnOnce() -> T,
F: 'scope,
T: 'scope,
{
let scope_notify = scope
.clone()
.map(|scope| (scope.add_running_thread(), scope));
let thread_data = Arc::new(ThreadData::new());
let notify = rt::Notify::new(true, false);

let id = {
let name = name.clone();
let thread_data = thread_data.clone();
let body: Box<dyn FnOnce() + 'scope> = Box::new(move || {
rt::execution(|execution| {
init_current(execution, name);
});

*thread_data.result.lock().unwrap() = Some(Ok(f()));
notify.notify(location);

if let Some((notifier, scope)) = scope_notify {
notifier.notify(location!());
scope.remove_running_thread()
}
});
rt::spawn(
stack_size,
std::mem::transmute::<_, Box<dyn FnOnce()>>(body),
)
};

JoinHandleInner {
data: thread_data,
notify,
thread: Thread {
id: ThreadId { id },
name,
},
}
}

/// Data for a running thread.
#[derive(Debug)]
struct ThreadData<'scope, T> {
result: Mutex<Option<std::thread::Result<T>>>,
_marker: PhantomData<Option<&'scope ScopeData>>,
}

impl<'scope, T> ThreadData<'scope, T> {
fn new() -> Self {
Self {
result: Mutex::new(None),
_marker: PhantomData,
}
}
}

impl<'scope, T> JoinHandleInner<'scope, T> {
fn join(self) -> std::thread::Result<T> {
self.notify.wait(location!());
self.data.result.lock().unwrap().take().unwrap()
}

fn thread(&self) -> &Thread {
&self.thread
}
}

#[derive(Default, Debug)]
struct ScopeThreads {
count: usize,
notify_on_finished: Vec<rt::Notify>,
}

#[derive(Debug)]
struct ScopeData {
running_threads: Mutex<ScopeThreads>,
main_thread: Thread,
}

impl ScopeData {
fn add_running_thread(&self) -> rt::Notify {
let mut running = self.running_threads.lock().unwrap();
running.count += 1;
let notify = rt::Notify::new(true, false);
running.notify_on_finished.push(notify);
notify
}

fn remove_running_thread(&self) {
let mut running = self.running_threads.lock().unwrap();
running.count -= 1;
if running.count == 0 {
self.main_thread.unpark()
}
}
}

0 comments on commit c181071

Please sign in to comment.