Follow requirement to allow finalizers to run #108

Closed
31 changes: 16 additions & 15 deletions cleanlab/pruning.py
@@ -29,6 +29,7 @@
 from sklearn.preprocessing import MultiLabelBinarizer
 import multiprocessing
 from multiprocessing.sharedctypes import RawArray
+from contextlib import contextmanager
 import sys
 import os
 import time
@@ -56,19 +57,19 @@
 # pruning, regardless if noise estimates are larger.
 MIN_NUM_PER_CLASS = 5
 
-# For python 2/3 compatibility, define pool context manager
-# to support the 'with' statement in Python 2
-if sys.version_info[0] == 2:
-    from contextlib import contextmanager
-
-
-    @contextmanager
-    def multiprocessing_context(*args, **kwargs):
-        pool = multiprocessing.Pool(*args, **kwargs)
-        yield pool
-        pool.terminate()
-else:
-    multiprocessing_context = multiprocessing.Pool
+# Coverage testing registers finalizers that don't run if
+# `multiprocessing.Pool.terminate()` is called
+# (`multiprocessing.Pool.__exit__()` also calls `terminate()`), so we use this
+# alternative context manager that calls `join()` to ensure finalizers are run
+# reliably.
+#
+# See https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html#if-you-use-multiprocessing-pool
+@contextmanager
+def joining_pool(*args, **kwargs):
+    pool = multiprocessing.Pool(*args, **kwargs)
+    yield pool
+    pool.close()
+    pool.join()
 
 # Globals to be shared across threads in multiprocessing
 mp_params = {}  # parameters passed to multiprocessing helper functions
@@ -392,7 +393,7 @@ def get_noise_indices(
     # Operations are parallelized across all CPU processes
     if prune_method == 'prune_by_class' or prune_method == 'both':
         if n_jobs > 1:  # parallelize
-            with multiprocessing_context(
+            with joining_pool(
                 n_jobs,
                 initializer=_init,
                 initargs=(_s, _s_counts, _prune_count_matrix,
@@ -417,7 +418,7 @@
 
     if prune_method == 'prune_by_noise_rate' or prune_method == 'both':
         if n_jobs > 1:  # parallelize
-            with multiprocessing_context(
+            with joining_pool(
                 n_jobs,
                 initializer=_init,
                 initargs=(_s, _s_counts, _prune_count_matrix,
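For context beyond the diff, here is a minimal, self-contained sketch of the pattern this change adopts. The worker function _square and the pool size are hypothetical stand-ins, not cleanlab code. The point is that exiting the context via close() followed by join() waits for the workers to shut down cleanly, so finalizers registered inside them, such as the one pytest-cov installs to flush subprocess coverage data, get a chance to run; terminate(), which multiprocessing.Pool.__exit__() also calls, kills the workers immediately instead.

import multiprocessing
from contextlib import contextmanager


@contextmanager
def joining_pool(*args, **kwargs):
    # Same shape as the helper added in this PR: yield a Pool, then
    # close() and join() on exit instead of terminate().
    pool = multiprocessing.Pool(*args, **kwargs)
    yield pool
    pool.close()  # stop accepting new work
    pool.join()   # wait for workers to exit, letting their finalizers run


def _square(x):  # hypothetical worker function, not part of cleanlab
    return x * x


if __name__ == '__main__':
    with joining_pool(2) as pool:
        print(pool.map(_square, range(5)))  # [0, 1, 4, 9, 16]

This also illustrates why the PR defines its own context manager rather than using `with multiprocessing.Pool(...)` directly: the built-in exit path would still call terminate().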