From d9122f617998f997657c594c5689f5e99497c36e Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Sat, 18 Nov 2023 06:36:04 -0500 Subject: [PATCH 1/6] Add an option to specify initializer function for processes --- joblib/parallel.py | 28 +++++++++++++++++++++++++++- joblib/test/test_parallel.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/joblib/parallel.py b/joblib/parallel.py index af842b42c..fa788dede 100644 --- a/joblib/parallel.py +++ b/joblib/parallel.py @@ -100,6 +100,8 @@ def _register_dask(): "mmap_mode": _Sentinel(default_value="r"), "prefer": _Sentinel(default_value=None), "require": _Sentinel(default_value=None), + "initializer": _Sentinel(default_value=None), + "initargs": _Sentinel(default_value=()), } @@ -312,6 +314,15 @@ class parallel_config: usable in some third-party library threadpools like OpenBLAS, MKL or OpenMP. This is only used with the ``loky`` backend. + initializer : callable, default=None + If not None, this will be called at the start of each worker + process. This can be used to install signal handlers or to + import additional modules for the worker. This can be used + with the ``loky`` and ``multiprocessing`` backends. + + initargs : tuple, default=None + Arguments for initializer. + backend_params : dict Additional parameters to pass to the backend constructor when backend is a string. @@ -356,6 +367,8 @@ def __init__( prefer=default_parallel_config["prefer"], require=default_parallel_config["require"], inner_max_num_threads=None, + initializer=default_parallel_config["initializer"], + initargs=default_parallel_config["initargs"], **backend_params ): # Save the parallel info and set the active parallel config @@ -375,7 +388,9 @@ def __init__( "mmap_mode": mmap_mode, "prefer": prefer, "require": require, - "backend": backend + "backend": backend, + "initializer": initializer, + "initargs": initargs, } self.parallel_config = self.old_parallel_config.copy() self.parallel_config.update({ @@ -1057,6 +1072,13 @@ class Parallel(Logger): disable memmapping, other modes defined in the numpy.memmap doc: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html Also, see 'max_nbytes' parameter documentation for more details. + initializer : callable, default=None + If not None, this will be called at the start of each worker + process. This can be used to install signal handlers or to + import additional modules for the worker. This can be used + with the ``loky`` and ``multiprocessing`` backends. + initargs : tuple, default=None + Arguments for initializer. Notes ----- @@ -1194,6 +1216,8 @@ def __init__( mmap_mode=default_parallel_config["mmap_mode"], prefer=default_parallel_config["prefer"], require=default_parallel_config["require"], + initializer=default_parallel_config["initializer"], + initargs=default_parallel_config["initargs"], ): # Initiate parent Logger class state super().__init__() @@ -1233,6 +1257,8 @@ def __init__( (prefer, "prefer"), (require, "require"), (verbose, "verbose"), + (initializer, "initializer"), + (initargs, "initargs"), ] } diff --git a/joblib/test/test_parallel.py b/joblib/test/test_parallel.py index e2d6e6f6a..b52b11dfe 100644 --- a/joblib/test/test_parallel.py +++ b/joblib/test/test_parallel.py @@ -520,7 +520,7 @@ def producer(): queue.append('Produced %i' % i) yield i - Parallel(n_jobs=2, batch_size=1, pre_dispatch=3, backend=backend)( + Parallel(n_jobs=12, batch_size=1, pre_dispatch=3, backend=backend)( delayed(consumer)(queue, 'any') for _ in producer()) queue_contents = list(queue) @@ -2008,3 +2008,31 @@ def parallel_call(n_jobs): parallel_call(n_jobs) executor = get_reusable_executor(reuse=True) assert executor == first_executor + + +@with_multiprocessing +@parametrize('n_jobs', [2, 3, -1]) +@parametrize('backend', PROCESS_BACKENDS) +@parametrize("context", [parallel_config, parallel_backend]) +def test_initializer(n_jobs, backend, context): + manager = mp.Manager() + val = manager.Value('i', 0) + + def initializer(val): + val.value += 1 + + n_jobs = effective_n_jobs(n_jobs) + with context( + backend=backend, + n_jobs=n_jobs, + initializer=initializer, + initargs=(val,) + ): + with Parallel(n_jobs=n_jobs) as parallel: + executor = get_workers(parallel._backend) + _ = parallel(delayed(square)(i) for i in range(n_jobs)) + + if backend == "loky": + assert val.value == len(executor._processes) + else: + assert val.value == n_jobs From e2832a855eb203408f4774eab623c97f2095d165 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 20 Nov 2023 09:11:01 -0500 Subject: [PATCH 2/6] Add tests --- joblib/test/test_parallel.py | 54 ++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/joblib/test/test_parallel.py b/joblib/test/test_parallel.py index b52b11dfe..b8cdae2e3 100644 --- a/joblib/test/test_parallel.py +++ b/joblib/test/test_parallel.py @@ -6,6 +6,7 @@ # Copyright (c) 2010-2011 Gael Varoquaux # License: BSD Style, 3 clauses. +import multiprocessing import os import sys import time @@ -2011,28 +2012,53 @@ def parallel_call(n_jobs): @with_multiprocessing -@parametrize('n_jobs', [2, 3, -1]) +@parametrize('n_jobs', [2, 4, -1]) @parametrize('backend', PROCESS_BACKENDS) @parametrize("context", [parallel_config, parallel_backend]) def test_initializer(n_jobs, backend, context): + n_jobs = effective_n_jobs(n_jobs) manager = mp.Manager() - val = manager.Value('i', 0) - - def initializer(val): - val.value += 1 + queue = manager.list() + + def initializer(queue): + queue.append("spam") - n_jobs = effective_n_jobs(n_jobs) with context( backend=backend, n_jobs=n_jobs, initializer=initializer, - initargs=(val,) + initargs=(queue,) ): - with Parallel(n_jobs=n_jobs) as parallel: - executor = get_workers(parallel._backend) - _ = parallel(delayed(square)(i) for i in range(n_jobs)) + with Parallel() as parallel: + values = parallel(delayed(square)(i) for i in range(n_jobs)) - if backend == "loky": - assert val.value == len(executor._processes) - else: - assert val.value == n_jobs + assert len(queue) == n_jobs + assert all(q == "spam" for q in queue) + + +@with_multiprocessing +@pytest.mark.parametrize("n_jobs", [2, 4, -1]) +def test_initializer_reuse(n_jobs): + # test that the initializer is called only once + # when the executor is reused + n_jobs = effective_n_jobs(n_jobs) + manager = mp.Manager() + queue = manager.list() + + def initializer(queue): + queue.append("spam") + + def parallel_call(n_jobs, initializer, initargs): + Parallel( + backend="loky", + n_jobs=n_jobs, + initializer=initializer, + initargs=(queue,), + )(delayed(square)(i) for i in range(n_jobs)) + + parallel_call(n_jobs, initializer, (queue,)) + assert len(queue) == n_jobs + + for i in range(10): + parallel_call(n_jobs, initializer, (queue,)) + assert len(queue) == n_jobs From cb2478a5de8fb3720c452654e7cfa1d23f3c3105 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 20 Nov 2023 09:13:24 -0500 Subject: [PATCH 3/6] Update docs --- joblib/parallel.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/joblib/parallel.py b/joblib/parallel.py index fa788dede..3f163a163 100644 --- a/joblib/parallel.py +++ b/joblib/parallel.py @@ -318,7 +318,9 @@ class parallel_config: If not None, this will be called at the start of each worker process. This can be used to install signal handlers or to import additional modules for the worker. This can be used - with the ``loky`` and ``multiprocessing`` backends. + with the ``loky`` and ``multiprocessing`` backends. Note that + when workers are reused, as is done by the ``loky`` backend, + initialization happens only once per worker. initargs : tuple, default=None Arguments for initializer. @@ -1076,7 +1078,9 @@ class Parallel(Logger): If not None, this will be called at the start of each worker process. This can be used to install signal handlers or to import additional modules for the worker. This can be used - with the ``loky`` and ``multiprocessing`` backends. + with the ``loky`` and ``multiprocessing`` backends. Note that + when workers are reused, as is done by the ``loky`` backend, + initialization happens only once per worker. initargs : tuple, default=None Arguments for initializer. From f86bfab2a3b3981d7546f41ca7b67fa9dfb6fe21 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 20 Nov 2023 09:17:23 -0500 Subject: [PATCH 4/6] Don't be too prescriptive about how to use initializer --- joblib/parallel.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/joblib/parallel.py b/joblib/parallel.py index 3f163a163..ef8d3ab19 100644 --- a/joblib/parallel.py +++ b/joblib/parallel.py @@ -316,11 +316,10 @@ class parallel_config: initializer : callable, default=None If not None, this will be called at the start of each worker - process. This can be used to install signal handlers or to - import additional modules for the worker. This can be used - with the ``loky`` and ``multiprocessing`` backends. Note that - when workers are reused, as is done by the ``loky`` backend, - initialization happens only once per worker. + process. This can be used with the ``loky`` and + ``multiprocessing`` backends. Note that when workers are + reused, as is done by the ``loky`` backend, initialization + happens only once per worker. initargs : tuple, default=None Arguments for initializer. @@ -1075,12 +1074,11 @@ class Parallel(Logger): https://numpy.org/doc/stable/reference/generated/numpy.memmap.html Also, see 'max_nbytes' parameter documentation for more details. initializer : callable, default=None - If not None, this will be called at the start of each worker - process. This can be used to install signal handlers or to - import additional modules for the worker. This can be used - with the ``loky`` and ``multiprocessing`` backends. Note that - when workers are reused, as is done by the ``loky`` backend, - initialization happens only once per worker. + If not None, this will be called at the start of each + worker process. This can be used with the ``loky`` and + ``multiprocessing`` backends. Note that when workers are + reused, as is done by the ``loky`` backend, initialization + happens only once per worker. initargs : tuple, default=None Arguments for initializer. From aa8fc4203ea7df88650fc0e05cddf33e48759986 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 20 Nov 2023 10:17:24 -0500 Subject: [PATCH 5/6] Move initializer definition and use batch_size=1 --- joblib/test/test_parallel.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/joblib/test/test_parallel.py b/joblib/test/test_parallel.py index b8cdae2e3..40118a200 100644 --- a/joblib/test/test_parallel.py +++ b/joblib/test/test_parallel.py @@ -2011,6 +2011,9 @@ def parallel_call(n_jobs): assert executor == first_executor +def initializer(queue): + queue.append("spam") + @with_multiprocessing @parametrize('n_jobs', [2, 4, -1]) @parametrize('backend', PROCESS_BACKENDS) @@ -2020,19 +2023,16 @@ def test_initializer(n_jobs, backend, context): manager = mp.Manager() queue = manager.list() - def initializer(queue): - queue.append("spam") - with context( backend=backend, n_jobs=n_jobs, initializer=initializer, - initargs=(queue,) + initargs=(queue,), ): - with Parallel() as parallel: - values = parallel(delayed(square)(i) for i in range(n_jobs)) + with Parallel(batch_size=1) as parallel: + pids = parallel(delayed(sleep_and_return_pid)() for i in range(n_jobs)) - assert len(queue) == n_jobs + assert len(queue) == len(set(pids)) assert all(q == "spam" for q in queue) @@ -2045,20 +2045,20 @@ def test_initializer_reuse(n_jobs): manager = mp.Manager() queue = manager.list() - def initializer(queue): - queue.append("spam") - def parallel_call(n_jobs, initializer, initargs): - Parallel( + return Parallel( backend="loky", + batch_size=1, n_jobs=n_jobs, initializer=initializer, initargs=(queue,), - )(delayed(square)(i) for i in range(n_jobs)) + )(delayed(sleep_and_return_pid)() for i in range(n_jobs)) - parallel_call(n_jobs, initializer, (queue,)) - assert len(queue) == n_jobs + pids = parallel_call(n_jobs, initializer, (queue,)) + assert len(queue) == len(set(pids)) + assert all(q == "spam" for q in queue) for i in range(10): - parallel_call(n_jobs, initializer, (queue,)) + pids = parallel_call(n_jobs, initializer, (queue,)) assert len(queue) == n_jobs + assert all(q == "spam" for q in queue) From 3cb1440db37a636e4f0065f4bcac33c86be5d0ef Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 20 Nov 2023 12:15:59 -0500 Subject: [PATCH 6/6] Revert "Move initializer definition and use batch_size=1" This reverts commit aa8fc4203ea7df88650fc0e05cddf33e48759986. --- joblib/test/test_parallel.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/joblib/test/test_parallel.py b/joblib/test/test_parallel.py index 40118a200..b8cdae2e3 100644 --- a/joblib/test/test_parallel.py +++ b/joblib/test/test_parallel.py @@ -2011,9 +2011,6 @@ def parallel_call(n_jobs): assert executor == first_executor -def initializer(queue): - queue.append("spam") - @with_multiprocessing @parametrize('n_jobs', [2, 4, -1]) @parametrize('backend', PROCESS_BACKENDS) @@ -2023,16 +2020,19 @@ def test_initializer(n_jobs, backend, context): manager = mp.Manager() queue = manager.list() + def initializer(queue): + queue.append("spam") + with context( backend=backend, n_jobs=n_jobs, initializer=initializer, - initargs=(queue,), + initargs=(queue,) ): - with Parallel(batch_size=1) as parallel: - pids = parallel(delayed(sleep_and_return_pid)() for i in range(n_jobs)) + with Parallel() as parallel: + values = parallel(delayed(square)(i) for i in range(n_jobs)) - assert len(queue) == len(set(pids)) + assert len(queue) == n_jobs assert all(q == "spam" for q in queue) @@ -2045,20 +2045,20 @@ def test_initializer_reuse(n_jobs): manager = mp.Manager() queue = manager.list() + def initializer(queue): + queue.append("spam") + def parallel_call(n_jobs, initializer, initargs): - return Parallel( + Parallel( backend="loky", - batch_size=1, n_jobs=n_jobs, initializer=initializer, initargs=(queue,), - )(delayed(sleep_and_return_pid)() for i in range(n_jobs)) + )(delayed(square)(i) for i in range(n_jobs)) - pids = parallel_call(n_jobs, initializer, (queue,)) - assert len(queue) == len(set(pids)) - assert all(q == "spam" for q in queue) + parallel_call(n_jobs, initializer, (queue,)) + assert len(queue) == n_jobs for i in range(10): - pids = parallel_call(n_jobs, initializer, (queue,)) + parallel_call(n_jobs, initializer, (queue,)) assert len(queue) == n_jobs - assert all(q == "spam" for q in queue)