Use global rayon threadpool for multicore #71

Merged: 6 commits, Aug 26, 2021
Changes from 5 commits
10 changes: 6 additions & 4 deletions CHANGELOG.md
@@ -7,10 +7,11 @@ and this project adheres to Rust's notion of

## [Unreleased]
### Added
- `BELLMAN_NUM_CPUS` environment variable, which can be used to control the
number of logical CPUs that `bellman` will use when the (default) `multicore`
feature flag is enabled. The default (which has not changed) is to use the
`num_cpus` crate to determine the number of logical CPUs.
- `bellman` now uses `rayon` for multithreading when the (default) `multicore`
feature flag is enabled. This means that, when this flag is enabled, the
`RAYON_NUM_THREADS` environment variable controls the number of threads that
`bellman` will use. The default, which has not changed, is to use the same
number of threads as logical CPUs.
- `bellman::multicore::Waiter`

### Changed
@@ -20,6 +21,7 @@ and this project adheres to Rust's notion of
- `bellman::multiexp::multiexp` now returns
`bellman::multicore::Waiter<Result<G, SynthesisError>>` instead of
`Box<dyn Future<Item = G, Error = SynthesisError>>`.
- `bellman::multicore::log_num_cpus` is renamed to `log_num_threads`.

### Removed
- `bellman::multicore::WorkerFuture` (replaced by `Waiter`).
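
The changelog entry above notes that the global rayon pool now governs bellman's parallelism, so the thread count is chosen exactly the way rayon chooses it. A minimal caller-side sketch of the two ways that count could be pinned (assuming the default `multicore` feature is enabled; neither snippet is part of this PR):

```rust
// Hypothetical downstream configuration of the global rayon pool that
// bellman will now share. Option 1 is external:
//
//     RAYON_NUM_THREADS=4 cargo run --release
//
// Option 2 is programmatic, and must run before any rayon work starts:
fn main() {
    rayon::ThreadPoolBuilder::new()
        .num_threads(4)
        .build_global()
        .expect("the global rayon pool can only be configured once");

    // With neither option, rayon defaults to one thread per logical CPU,
    // matching bellman's previous behaviour.
    assert_eq!(rayon::current_num_threads(), 4);
}
```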
6 changes: 3 additions & 3 deletions Cargo.toml
@@ -23,11 +23,11 @@ byteorder = "1"
subtle = "2.2.1"

# Multicore dependencies
crossbeam-channel = { version = "0.5", optional = true }
crossbeam-channel = { version = "0.5.1", optional = true }
lazy_static = { version = "1.4.0", optional = true }
log = { version = "0.4.8", optional = true }
log = { version = "0.4", optional = true }
num_cpus = { version = "1", optional = true }
rayon = { version = "1.3.0", optional = true }
rayon = { version = "1.5.1", optional = true }

[dev-dependencies]
bls12_381 = "0.5"
2 changes: 1 addition & 1 deletion src/domain.rs
@@ -259,7 +259,7 @@ impl<S: PrimeField> Group<S> for Scalar<S> {
}

fn best_fft<S: PrimeField, T: Group<S>>(a: &mut [T], worker: &Worker, omega: &S, log_n: u32) {
let log_cpus = worker.log_num_cpus();
let log_cpus = worker.log_num_threads();

if log_n <= log_cpus {
serial_fft(a, omega, log_n);
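
As a quick illustration of how the renamed helper drives the serial/parallel split above: `log_num_threads()` is the floor of log2 of rayon's current thread count, and `best_fft` only takes the parallel path when `log_n` exceeds it. A small sketch, assuming `log2_floor` matches bellman's private helper (it is not shown in this diff):

```rust
// Assumed to mirror bellman's internal log2_floor helper.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);
    let mut pow = 0;
    while (1 << (pow + 1)) <= num {
        pow += 1;
    }
    pow
}

fn main() {
    // With, say, 6 rayon threads, log_num_threads() == 2, so any FFT with
    // log_n <= 2 (at most 4 coefficients) stays on the serial path and
    // larger ones go to the parallel path.
    assert_eq!(log2_floor(6), 2);
    assert_eq!(log2_floor(8), 3);
}
```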
45 changes: 19 additions & 26 deletions src/multicore.rs
@@ -4,39 +4,30 @@

#[cfg(feature = "multicore")]
mod implementation {
use std::env;
use std::sync::atomic::{AtomicUsize, Ordering};

use crossbeam_channel::{bounded, Receiver};
use lazy_static::lazy_static;
use log::{error, trace};
use num_cpus;
use rayon::current_num_threads;

static WORKER_SPAWN_COUNTER: AtomicUsize = AtomicUsize::new(0);

lazy_static! {
static ref NUM_CPUS: usize = env::var("BELLMAN_NUM_CPUS")
.map_err(|_| ())
.and_then(|num| num.parse().map_err(|_| ()))
.unwrap_or_else(|_| num_cpus::get());
// See Worker::compute below for a description of this.
static ref WORKER_SPAWN_MAX_COUNT: usize = *NUM_CPUS * 4;
pub static ref THREAD_POOL: rayon::ThreadPool = rayon::ThreadPoolBuilder::new()
.num_threads(*NUM_CPUS)
.build()
.unwrap();
static ref WORKER_SPAWN_MAX_COUNT: usize = current_num_threads() * 4;
}

#[derive(Clone)]
#[derive(Clone, Default)]
pub struct Worker {}

impl Worker {
pub fn new() -> Worker {
Worker {}
}

pub fn log_num_cpus(&self) -> u32 {
log2_floor(*NUM_CPUS)
pub fn log_num_threads(&self) -> u32 {
log2_floor(current_num_threads())
}

pub fn compute<F, R>(&self, f: F) -> Waiter<R>
@@ -45,7 +36,6 @@ mod implementation {
R: Send + 'static,
{
let (sender, receiver) = bounded(1);
let thread_index = THREAD_POOL.current_thread_index().unwrap_or(0);

// We keep track here of how many times spawn has been called.
// It can be called without limit, each time, putting a
@@ -63,17 +53,18 @@ mod implementation {
// install call to help clear the growing work queue and
// minimize the chances of memory exhaustion.
if previous_count > *WORKER_SPAWN_MAX_COUNT {
THREAD_POOL.install(move || {
trace!("[{}] switching to install to help clear backlog[current threads {}, threads requested {}]",
thread_index,
THREAD_POOL.current_num_threads(),
WORKER_SPAWN_COUNTER.load(Ordering::SeqCst));
let thread_index = rayon::current_thread_index().unwrap_or(0);
rayon::scope(move |_| {
trace!("[{}] switching to scope to help clear backlog [threads: current {}, requested {}]",
thread_index,
current_num_threads(),
WORKER_SPAWN_COUNTER.load(Ordering::SeqCst));
let res = f();
sender.send(res).unwrap();
WORKER_SPAWN_COUNTER.fetch_sub(1, Ordering::SeqCst);
});
} else {
THREAD_POOL.spawn(move || {
rayon::spawn(move || {
let res = f();
sender.send(res).unwrap();
WORKER_SPAWN_COUNTER.fetch_sub(1, Ordering::SeqCst);
@@ -88,13 +79,14 @@ mod implementation {
F: FnOnce(&rayon::Scope<'a>, usize) -> R + Send,
R: Send,
{
let chunk_size = if elements < *NUM_CPUS {
let num_threads = current_num_threads();
let chunk_size = if elements < num_threads {
1
} else {
elements / *NUM_CPUS
elements / num_threads
};

THREAD_POOL.scope(|scope| f(scope, chunk_size))
rayon::scope(|scope| f(scope, chunk_size))
}
}

@@ -105,8 +97,9 @@ mod implementation {
impl<T> Waiter<T> {
/// Wait for the result.
pub fn wait(&self) -> T {
if THREAD_POOL.current_thread_index().is_some() {
let msg = "wait() cannot be called from within the worker thread pool since that would lead to deadlocks";
// This will be Some if this thread is in the global thread pool.
if rayon::current_thread_index().is_some() {
let msg = "wait() cannot be called from within a thread pool since that would lead to deadlocks";
// panic! doesn't necessarily kill the process, so we log as well.
error!("{}", msg);
panic!("{}", msg);
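
For context on how the pieces above fit together once the global pool is used: `Worker::compute` dispatches the closure via `rayon::spawn` (or a scope when the backlog counter trips) and returns a `Waiter`, whose `wait()` must be called from outside the rayon pool, e.g. the main thread, to avoid the deadlock the panic above guards against. A hypothetical caller-side sketch, not taken from this PR or its tests:

```rust
// Hypothetical usage of the Worker/Waiter API after this change; the names
// Worker, compute, and wait come from the diff above, while the closure and
// values are made up for illustration.
use bellman::multicore::Worker;

fn main() {
    let worker = Worker::new();

    // The closure runs on the global rayon pool; its result travels back
    // through the bounded crossbeam channel held by the Waiter.
    let waiter = worker.compute(|| (0u64..1_000).sum::<u64>());

    // Safe here because main() is not a rayon worker thread; inside the
    // pool this call would panic, as shown above.
    assert_eq!(waiter.wait(), 499_500);
}
```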