diff --git a/CHANGELOG.md b/CHANGELOG.md
index c0896ed8d..0da3a9cc3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,10 +7,11 @@ and this project adheres to Rust's notion of
 
 ## [Unreleased]
 ### Added
-- `BELLMAN_NUM_CPUS` environment variable, which can be used to control the
-  number of logical CPUs that `bellman` will use when the (default) `multicore`
-  feature flag is enabled. The default (which has not changed) is to use the
-  `num_cpus` crate to determine the number of logical CPUs.
+- `bellman` now uses `rayon` for multithreading when the (default) `multicore`
+  feature flag is enabled. This means that, when this flag is enabled, the
+  `RAYON_NUM_THREADS` environment variable controls the number of threads that
+  `bellman` will use. The default, which has not changed, is to use the same
+  number of threads as logical CPUs.
 - `bellman::multicore::Waiter`
 
 ### Changed
@@ -20,6 +21,7 @@ and this project adheres to Rust's notion of
   - `bellman::multiexp::multiexp` now returns
     `bellman::multicore::Waiter<Result<G, SynthesisError>>` instead of
     `Box<dyn Future<Item = G, Error = SynthesisError>>`.
+  - `bellman::multicore::log_num_cpus` is renamed to `log_num_threads`.
 
 ### Removed
 - `bellman::multicore::WorkerFuture` (replaced by `Waiter`).
diff --git a/Cargo.toml b/Cargo.toml
index 973a7e883..7b7626baa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,11 +23,11 @@ byteorder = "1"
 subtle = "2.2.1"
 
 # Multicore dependencies
-crossbeam-channel = { version = "0.5", optional = true }
+crossbeam-channel = { version = "0.5.1", optional = true }
 lazy_static = { version = "1.4.0", optional = true }
-log = { version = "0.4.8", optional = true }
+log = { version = "0.4", optional = true }
 num_cpus = { version = "1", optional = true }
-rayon = { version = "1.3.0", optional = true }
+rayon = { version = "1.5.1", optional = true }
 
 [dev-dependencies]
 bls12_381 = "0.5"
diff --git a/src/domain.rs b/src/domain.rs
index fa02adce9..dc27c565b 100644
--- a/src/domain.rs
+++ b/src/domain.rs
@@ -259,7 +259,7 @@ impl<S: PrimeField> Group<S> for Scalar<S> {
 }
 
 fn best_fft<S: PrimeField, T: Group<S>>(a: &mut [T], worker: &Worker, omega: &S, log_n: u32) {
-    let log_cpus = worker.log_num_cpus();
+    let log_cpus = worker.log_num_threads();
 
     if log_n <= log_cpus {
         serial_fft(a, omega, log_n);
diff --git a/src/multicore.rs b/src/multicore.rs
index f7c0430b2..6326d4e40 100644
--- a/src/multicore.rs
+++ b/src/multicore.rs
@@ -4,30 +4,21 @@
 
 #[cfg(feature = "multicore")]
 mod implementation {
-    use std::env;
     use std::sync::atomic::{AtomicUsize, Ordering};
 
     use crossbeam_channel::{bounded, Receiver};
     use lazy_static::lazy_static;
    use log::{error, trace};
-    use num_cpus;
+    use rayon::current_num_threads;
 
     static WORKER_SPAWN_COUNTER: AtomicUsize = AtomicUsize::new(0);
 
     lazy_static! {
-        static ref NUM_CPUS: usize = env::var("BELLMAN_NUM_CPUS")
-            .map_err(|_| ())
-            .and_then(|num| num.parse().map_err(|_| ()))
-            .unwrap_or_else(|_| num_cpus::get());
         // See Worker::compute below for a description of this.
-        static ref WORKER_SPAWN_MAX_COUNT: usize = *NUM_CPUS * 4;
-        pub static ref THREAD_POOL: rayon::ThreadPool = rayon::ThreadPoolBuilder::new()
-            .num_threads(*NUM_CPUS)
-            .build()
-            .unwrap();
+        static ref WORKER_SPAWN_MAX_COUNT: usize = current_num_threads() * 4;
     }
 
-    #[derive(Clone)]
+    #[derive(Clone, Default)]
     pub struct Worker {}
 
     impl Worker {
@@ -35,8 +26,8 @@ mod implementation {
             Worker {}
         }
 
-        pub fn log_num_cpus(&self) -> u32 {
-            log2_floor(*NUM_CPUS)
+        pub fn log_num_threads(&self) -> u32 {
+            log2_floor(current_num_threads())
         }
 
         pub fn compute<F, R>(&self, f: F) -> Waiter<R>
@@ -45,7 +36,6 @@
             F: FnOnce() -> R + Send + 'static,
             R: Send + 'static,
         {
             let (sender, receiver) = bounded(1);
-            let thread_index = THREAD_POOL.current_thread_index().unwrap_or(0);
             // We keep track here of how many times spawn has been called.
             // It can be called without limit, each time, putting a
@@ -63,17 +53,18 @@
             // install call to help clear the growing work queue and
             // minimize the chances of memory exhaustion.
             if previous_count > *WORKER_SPAWN_MAX_COUNT {
-                THREAD_POOL.install(move || {
-                    trace!("[{}] switching to install to help clear backlog[current threads {}, threads requested {}]",
-                        thread_index,
-                        THREAD_POOL.current_num_threads(),
-                        WORKER_SPAWN_COUNTER.load(Ordering::SeqCst));
+                let thread_index = rayon::current_thread_index().unwrap_or(0);
+                rayon::scope(move |_| {
+                    trace!("[{}] switching to scope to help clear backlog [threads: current {}, requested {}]",
+                        thread_index,
+                        current_num_threads(),
+                        WORKER_SPAWN_COUNTER.load(Ordering::SeqCst));
                     let res = f();
                     sender.send(res).unwrap();
                     WORKER_SPAWN_COUNTER.fetch_sub(1, Ordering::SeqCst);
                 });
             } else {
-                THREAD_POOL.spawn(move || {
+                rayon::spawn(move || {
                     let res = f();
                     sender.send(res).unwrap();
                     WORKER_SPAWN_COUNTER.fetch_sub(1, Ordering::SeqCst);
@@ -88,13 +79,14 @@ mod implementation {
             F: FnOnce(&rayon::Scope<'a>, usize) -> R + Send,
             R: Send,
         {
-            let chunk_size = if elements < *NUM_CPUS {
+            let num_threads = current_num_threads();
+            let chunk_size = if elements < num_threads {
                 1
             } else {
-                elements / *NUM_CPUS
+                elements / num_threads
             };
 
-            THREAD_POOL.scope(|scope| f(scope, chunk_size))
+            rayon::scope(|scope| f(scope, chunk_size))
         }
     }
 
@@ -105,8 +97,9 @@ mod implementation {
     impl<T> Waiter<T> {
         /// Wait for the result.
         pub fn wait(&self) -> T {
-            if THREAD_POOL.current_thread_index().is_some() {
-                let msg = "wait() cannot be called from within the worker thread pool since that would lead to deadlocks";
+            // This will be Some if this thread is in the global thread pool.
+            if rayon::current_thread_index().is_some() {
+                let msg = "wait() cannot be called from within a thread pool since that would lead to deadlocks";
                 // panic! doesn't necessarily kill the process, so we log as well.
                 error!("{}", msg);
                 panic!("{}", msg);
@@ -158,7 +151,7 @@ mod implementation {
             Worker
         }
 
-        pub fn log_num_cpus(&self) -> u32 {
+        pub fn log_num_threads(&self) -> u32 {
             0
         }
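For context, and not part of the diff above: a minimal, self-contained sketch of the spawn-and-wait pattern the new `Worker::compute`/`Waiter` pair relies on, under the assumption that only the `rayon` and `crossbeam-channel` crates are available. The free-standing `compute` function, the `main`, and the summation workload are illustrative only, and the real code's backlog throttling via `WORKER_SPAWN_COUNTER` and `rayon::scope` is deliberately omitted here.

```rust
use crossbeam_channel::{bounded, Receiver};

struct Waiter<T> {
    receiver: Receiver<T>,
}

impl<T> Waiter<T> {
    fn wait(&self) -> T {
        // `current_thread_index()` is `Some` only on threads owned by the global
        // rayon pool; blocking there could deadlock the pool, so refuse.
        assert!(
            rayon::current_thread_index().is_none(),
            "wait() cannot be called from within a thread pool"
        );
        self.receiver.recv().unwrap()
    }
}

fn compute<F, R>(f: F) -> Waiter<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let (sender, receiver) = bounded(1);
    // Hand the closure to rayon's global pool. The pool is sized from
    // RAYON_NUM_THREADS (or the logical CPU count) the first time it is used.
    rayon::spawn(move || {
        // Ignoring a send error: it just means the Waiter was dropped.
        let _ = sender.send(f());
    });
    Waiter { receiver }
}

fn main() {
    // e.g. `RAYON_NUM_THREADS=4 cargo run` caps the pool at four threads.
    println!("rayon threads: {}", rayon::current_num_threads());
    let waiter = compute(|| (0u64..1_000).sum::<u64>());
    println!("sum = {}", waiter.wait());
}
```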