Add option to specify an initialization function for 'loky' and 'multiprocessing' backends #1525

Draft · wants to merge 6 commits into main
joblib/parallel.py: 29 additions & 1 deletion
@@ -100,6 +100,8 @@ def _register_dask():
"mmap_mode": _Sentinel(default_value="r"),
"prefer": _Sentinel(default_value=None),
"require": _Sentinel(default_value=None),
"initializer": _Sentinel(default_value=None),
"initargs": _Sentinel(default_value=()),
}


@@ -312,6 +314,16 @@ class parallel_config:
usable in some third-party library threadpools like OpenBLAS,
MKL or OpenMP. This is only used with the ``loky`` backend.

initializer : callable, default=None
If not None, this will be called at the start of each worker
process. This can be used with the ``loky`` and
``multiprocessing`` backends. Note that when workers are
reused, as is done by the ``loky`` backend, initialization
happens only once per worker.

initargs : tuple, default=()
Arguments passed to the initializer.

backend_params : dict
Additional parameters to pass to the backend constructor when
backend is a string.
@@ -356,6 +368,8 @@ def __init__(
prefer=default_parallel_config["prefer"],
require=default_parallel_config["require"],
inner_max_num_threads=None,
initializer=default_parallel_config["initializer"],
initargs=default_parallel_config["initargs"],
**backend_params
):
# Save the parallel info and set the active parallel config
@@ -375,7 +389,9 @@ def __init__(
"mmap_mode": mmap_mode,
"prefer": prefer,
"require": require,
- "backend": backend
+ "backend": backend,
+ "initializer": initializer,
+ "initargs": initargs,
}
self.parallel_config = self.old_parallel_config.copy()
self.parallel_config.update({
@@ -1057,6 +1073,14 @@ class Parallel(Logger):
disable memmapping, other modes defined in the numpy.memmap doc:
https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
Also, see 'max_nbytes' parameter documentation for more details.
initializer : callable, default=None
If not None, this will be called at the start of each
worker process. This can be used with the ``loky`` and
``multiprocessing`` backends. Note that when workers are
reused, as is done by the ``loky`` backend, initialization
happens only once per worker.
initargs : tuple, default=()
Arguments passed to the initializer.

Notes
-----
@@ -1194,6 +1218,8 @@ def __init__(
mmap_mode=default_parallel_config["mmap_mode"],
prefer=default_parallel_config["prefer"],
require=default_parallel_config["require"],
initializer=default_parallel_config["initializer"],
initargs=default_parallel_config["initargs"],
):
# Initiate parent Logger class state
super().__init__()
@@ -1233,6 +1259,8 @@ def __init__(
(prefer, "prefer"),
(require, "require"),
(verbose, "verbose"),
(initializer, "initializer"),
(initargs, "initargs"),
]
}

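Taken together, the changes above expose the new options both through ``parallel_config`` and directly on ``Parallel``. A minimal usage sketch, assuming the API exactly as added in this diff; the ``init_worker`` helper and the environment variable it sets are made up for illustration:

```python
import os

from joblib import Parallel, delayed, parallel_config


def init_worker(flag):
    # Hypothetical initializer: runs once in each worker process, and only
    # once per worker lifetime when loky reuses its executor.
    os.environ["WORKER_FLAG"] = flag


# Via the parallel_config context manager:
with parallel_config(backend="loky", initializer=init_worker, initargs=("on",)):
    Parallel(n_jobs=2)(delayed(abs)(i) for i in range(8))

# Or directly on the Parallel constructor:
Parallel(n_jobs=2, backend="loky", initializer=init_worker, initargs=("on",))(
    delayed(abs)(i) for i in range(8)
)
```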
joblib/test/test_parallel.py: 55 additions & 1 deletion
@@ -6,6 +6,7 @@
# Copyright (c) 2010-2011 Gael Varoquaux
# License: BSD Style, 3 clauses.

import multiprocessing
import os
import sys
import time
@@ -520,7 +521,7 @@ def producer():
queue.append('Produced %i' % i)
yield i

- Parallel(n_jobs=2, batch_size=1, pre_dispatch=3, backend=backend)(
+ Parallel(n_jobs=12, batch_size=1, pre_dispatch=3, backend=backend)(
delayed(consumer)(queue, 'any') for _ in producer())

queue_contents = list(queue)
@@ -2008,3 +2009,56 @@ def parallel_call(n_jobs):
parallel_call(n_jobs)
executor = get_reusable_executor(reuse=True)
assert executor == first_executor


@with_multiprocessing
@parametrize('n_jobs', [2, 4, -1])
@parametrize('backend', PROCESS_BACKENDS)
@parametrize("context", [parallel_config, parallel_backend])
def test_initializer(n_jobs, backend, context):
n_jobs = effective_n_jobs(n_jobs)
manager = mp.Manager()
queue = manager.list()

def initializer(queue):
queue.append("spam")

with context(
backend=backend,
n_jobs=n_jobs,
initializer=initializer,
initargs=(queue,)
):
with Parallel() as parallel:
values = parallel(delayed(square)(i) for i in range(n_jobs))

assert len(queue) == n_jobs
assert all(q == "spam" for q in queue)
Comment on lines +2035 to +2036 (review comment by the PR author):
@ogrisel I'm actually not sure what assertion I can make here (the current assertion seems not always to be true).

I want to test that every process that was started has run its init function.

However, it seems I sometimes reach line #2035 before all processes have been initialized, and by that time the work has already been completed by fewer workers than the number that were started.

Do you have any suggestions? Thanks!

Follow-up comment by the author:
Hi @ogrisel -- just a gentle ping about this if you have any suggestions; thanks!
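One possible way to make the assertion deterministic, offered as an untested sketch rather than a definitive fix: have each task block until every worker has appended its marker, so the ``Parallel`` call cannot finish before all ``n_jobs`` initializers have run. The ``blocking_task`` helper is hypothetical:

```python
import time


def blocking_task(queue, n_jobs, timeout=30):
    # Hypothetical task body: each task waits until all workers have run
    # their initializer, so the Parallel call cannot complete early.
    deadline = time.time() + timeout
    while len(queue) < n_jobs:
        if time.time() > deadline:
            raise RuntimeError("not all initializers ran before the timeout")
        time.sleep(0.05)
    return len(queue)


# In test_initializer, replace the `square` tasks with blocking ones:
# values = parallel(delayed(blocking_task)(queue, n_jobs) for _ in range(n_jobs))
# assert len(queue) == n_jobs
```

This assumes loky spins up all ``n_jobs`` workers while tasks are pending; if that cannot be guaranteed, a weaker assertion such as ``1 <= len(queue) <= n_jobs`` combined with ``all(q == "spam" for q in queue)`` may be the safer check.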



@with_multiprocessing
@pytest.mark.parametrize("n_jobs", [2, 4, -1])
def test_initializer_reuse(n_jobs):
# test that the initializer is called only once
# when the executor is reused
n_jobs = effective_n_jobs(n_jobs)
manager = mp.Manager()
queue = manager.list()

def initializer(queue):
queue.append("spam")

def parallel_call(n_jobs, initializer, initargs):
    Parallel(
        backend="loky",
        n_jobs=n_jobs,
        initializer=initializer,
        initargs=initargs,
    )(delayed(square)(i) for i in range(n_jobs))

parallel_call(n_jobs, initializer, (queue,))
assert len(queue) == n_jobs

for i in range(10):
parallel_call(n_jobs, initializer, (queue,))
assert len(queue) == n_jobs
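
The loop above leans on loky's executor reuse: as long as the executor parameters (including the initializer) do not change between calls, the same worker processes are kept alive and the initializer never runs again. A hedged sketch of asserting the reuse directly, mirroring the ``get_reusable_executor`` check earlier in this file:

```python
from joblib.executor import get_reusable_executor

# Assuming a fresh `queue` before the first call.
parallel_call(n_jobs, initializer, (queue,))
first_executor = get_reusable_executor(reuse=True)

parallel_call(n_jobs, initializer, (queue,))

# Same executor => same workers => each initializer ran only once.
assert get_reusable_executor(reuse=True) == first_executor
assert len(queue) == n_jobs
```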