diff --git a/cleanlab/pruning.py b/cleanlab/pruning.py index 63e4a377dc..cc30f3a8a1 100644 --- a/cleanlab/pruning.py +++ b/cleanlab/pruning.py @@ -29,6 +29,7 @@ from sklearn.preprocessing import MultiLabelBinarizer import multiprocessing from multiprocessing.sharedctypes import RawArray +from contextlib import contextmanager import sys import os import time @@ -56,19 +57,19 @@ # pruning, regardless if noise estimates are larger. MIN_NUM_PER_CLASS = 5 -# For python 2/3 compatibility, define pool context manager -# to support the 'with' statement in Python 2 -if sys.version_info[0] == 2: - from contextlib import contextmanager - - - @contextmanager - def multiprocessing_context(*args, **kwargs): - pool = multiprocessing.Pool(*args, **kwargs) - yield pool - pool.terminate() -else: - multiprocessing_context = multiprocessing.Pool +# Coverage testing registers finalizers that don't run if +# `multiprocessing.Pool.terminate()` is called +# (`multiprocessing.Pool.__exit__()` also calls `terminate()`), so we use this +# alternative context manager that calls `join()` to ensure finalizers are run +# reliably. 
@contextmanager
def joining_pool(*args, **kwargs):
    """Context manager yielding a ``multiprocessing.Pool`` that is joined on exit.

    ``multiprocessing.Pool.__exit__()`` calls ``terminate()``, which prevents
    finalizers registered in the workers (e.g. by coverage tooling) from
    running. On a normal exit this manager instead calls ``close()`` followed
    by ``join()`` so workers shut down gracefully and their finalizers run.

    See https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html#if-you-use-multiprocessing-pool

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``multiprocessing.Pool``.

    Yields
    ------
    multiprocessing.Pool
        The live pool, usable for the duration of the ``with`` block.

    Notes
    -----
    If the ``with`` body raises, the pool is terminated (outstanding work is
    abandoned) and joined before the exception propagates, so worker
    processes are never leaked.
    """
    pool = multiprocessing.Pool(*args, **kwargs)
    try:
        yield pool
    except BaseException:
        # The body failed: don't wait on remaining tasks, just stop the
        # workers so they are not left running after the error propagates.
        pool.terminate()
        raise
    else:
        # Graceful shutdown: lets workers exit normally so finalizers run.
        pool.close()
    finally:
        # join() is valid after either close() or terminate().
        pool.join()