diff --git a/joblib/parallel.py b/joblib/parallel.py index af842b42c..ef8d3ab19 100644 --- a/joblib/parallel.py +++ b/joblib/parallel.py @@ -100,6 +100,8 @@ def _register_dask(): "mmap_mode": _Sentinel(default_value="r"), "prefer": _Sentinel(default_value=None), "require": _Sentinel(default_value=None), + "initializer": _Sentinel(default_value=None), + "initargs": _Sentinel(default_value=()), } @@ -312,6 +314,16 @@ class parallel_config: usable in some third-party library threadpools like OpenBLAS, MKL or OpenMP. This is only used with the ``loky`` backend. + initializer : callable, default=None + If not None, this will be called at the start of each worker + process. This can be used with the ``loky`` and + ``multiprocessing`` backends. Note that when workers are + reused, as is done by the ``loky`` backend, initialization + happens only once per worker. + + initargs : tuple, default=() + Arguments for initializer. + backend_params : dict Additional parameters to pass to the backend constructor when backend is a string. @@ -356,6 +368,8 @@ def __init__( prefer=default_parallel_config["prefer"], require=default_parallel_config["require"], inner_max_num_threads=None, + initializer=default_parallel_config["initializer"], + initargs=default_parallel_config["initargs"], **backend_params ): # Save the parallel info and set the active parallel config @@ -375,7 +389,9 @@ def __init__( "mmap_mode": mmap_mode, "prefer": prefer, "require": require, - "backend": backend + "backend": backend, + "initializer": initializer, + "initargs": initargs, } self.parallel_config = self.old_parallel_config.copy() self.parallel_config.update({ @@ -1057,6 +1073,14 @@ class Parallel(Logger): disable memmapping, other modes defined in the numpy.memmap doc: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html Also, see 'max_nbytes' parameter documentation for more details. 
+ initializer : callable, default=None + If not None, this will be called at the start of each + worker process. This can be used with the ``loky`` and + ``multiprocessing`` backends. Note that when workers are + reused, as is done by the ``loky`` backend, initialization + happens only once per worker. + initargs : tuple, default=() + Arguments for initializer. Notes ----- @@ -1194,6 +1218,8 @@ def __init__( mmap_mode=default_parallel_config["mmap_mode"], prefer=default_parallel_config["prefer"], require=default_parallel_config["require"], + initializer=default_parallel_config["initializer"], + initargs=default_parallel_config["initargs"], ): # Initiate parent Logger class state super().__init__() @@ -1233,6 +1259,8 @@ def __init__( (prefer, "prefer"), (require, "require"), (verbose, "verbose"), + (initializer, "initializer"), + (initargs, "initargs"), ] } diff --git a/joblib/test/test_parallel.py b/joblib/test/test_parallel.py index e2d6e6f6a..b8cdae2e3 100644 --- a/joblib/test/test_parallel.py +++ b/joblib/test/test_parallel.py @@ -6,6 +6,7 @@ # Copyright (c) 2010-2011 Gael Varoquaux # License: BSD Style, 3 clauses. 
+import multiprocessing import os import sys import time @@ -2008,3 +2009,56 @@ def parallel_call(n_jobs): parallel_call(n_jobs) executor = get_reusable_executor(reuse=True) assert executor == first_executor + + +@with_multiprocessing +@parametrize('n_jobs', [2, 4, -1]) +@parametrize('backend', PROCESS_BACKENDS) +@parametrize("context", [parallel_config, parallel_backend]) +def test_initializer(n_jobs, backend, context): + n_jobs = effective_n_jobs(n_jobs) + manager = multiprocessing.Manager() + queue = manager.list() + + def initializer(queue): + queue.append("spam") + + with context( + backend=backend, + n_jobs=n_jobs, + initializer=initializer, + initargs=(queue,) + ): + with Parallel() as parallel: + parallel(delayed(square)(i) for i in range(n_jobs)) + + assert len(queue) == n_jobs + assert all(q == "spam" for q in queue) + + +@with_multiprocessing +@parametrize("n_jobs", [2, 4, -1]) +def test_initializer_reuse(n_jobs): + # test that the initializer is called only once + # when the executor is reused + n_jobs = effective_n_jobs(n_jobs) + manager = multiprocessing.Manager() + queue = manager.list() + + def initializer(queue): + queue.append("spam") + + def parallel_call(n_jobs, initializer, initargs): + Parallel( + backend="loky", + n_jobs=n_jobs, + initializer=initializer, + initargs=initargs, + )(delayed(square)(i) for i in range(n_jobs)) + + parallel_call(n_jobs, initializer, (queue,)) + assert len(queue) == n_jobs + + for _ in range(10): + parallel_call(n_jobs, initializer, (queue,)) + assert len(queue) == n_jobs