diff --git a/joblib/test/test_dask.py b/joblib/test/test_dask.py index a882834dc..9f072a128 100644 --- a/joblib/test/test_dask.py +++ b/joblib/test/test_dask.py @@ -1,5 +1,6 @@ from __future__ import print_function, division, absolute_import import os +import warnings import pytest from random import random @@ -475,7 +476,7 @@ def func_using_joblib_parallel(): # pytest.warns(UserWarning)) make the test hang. Work-around: return # the warning record to the client and the warning check is done # client-side. - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: Parallel(n_jobs=2, backend=backend)( delayed(inc)(i) for i in range(10)) diff --git a/joblib/test/test_numpy_pickle.py b/joblib/test/test_numpy_pickle.py index cce16e381..a6776f634 100644 --- a/joblib/test/test_numpy_pickle.py +++ b/joblib/test/test_numpy_pickle.py @@ -300,7 +300,7 @@ def test_cache_size_warning(tmpdir, cache_size): a = rnd.random_sample((10, 2)) warnings.simplefilter("always") - with warns(None) as warninfo: + with warnings.catch_warnings(record=True) as warninfo: numpy_pickle.dump(a, filename, cache_size=cache_size) expected_nb_warnings = 1 if cache_size is not None else 0 assert len(warninfo) == expected_nb_warnings @@ -385,7 +385,7 @@ def _check_pickle(filename, expected_list, mmap_mode=None): py_version_used_for_writing, 4) if pickle_reading_protocol >= pickle_writing_protocol: try: - with warns(None) as warninfo: + with warnings.catch_warnings(record=True) as warninfo: warnings.simplefilter('always') warnings.filterwarnings( 'ignore', module='numpy', diff --git a/joblib/test/test_parallel.py b/joblib/test/test_parallel.py index 884d9d76b..906d43629 100644 --- a/joblib/test/test_parallel.py +++ b/joblib/test/test_parallel.py @@ -10,6 +10,7 @@ import sys import time import mmap +import warnings import threading from traceback import format_exception from math import sqrt @@ -20,8 +21,6 @@ import warnings import pytest -from importlib import reload - import joblib from joblib import parallel from joblib import dump, load @@ -31,7 +30,7 @@ from joblib.test.common import np, with_numpy from joblib.test.common import with_multiprocessing from joblib.testing import (parametrize, raises, check_subprocess_call, - skipif, SkipTest, warns) + skipif, warns) if mp is not None: # Loky is not available if multiprocessing is not @@ -181,7 +180,7 @@ def test_main_thread_renamed_no_warning(backend, monkeypatch): monkeypatch.setattr(target=threading.current_thread(), name='name', value='some_new_name_for_the_main_thread') - with warns(None) as warninfo: + with warnings.catch_warnings(record=True) as warninfo: results = Parallel(n_jobs=2, backend=backend)( delayed(square)(x) for x in range(3)) assert results == [0, 1, 4] @@ -197,18 +196,21 @@ def test_main_thread_renamed_no_warning(backend, monkeypatch): def _assert_warning_nested(backend, inner_n_jobs, expected): - with warnings.catch_warnings(record=True) as records: + with warnings.catch_warnings(record=True) as warninfo: warnings.simplefilter("always") parallel_func(backend=backend, inner_n_jobs=inner_n_jobs) - messages = [w.message for w in records] + warninfo = [w.message for w in warninfo] if expected: - # with threading, we might see more that one records - if messages: - return 'backed parallel loops cannot' in messages[0].args[0] + # with threading, we might see more that one warninfo + if warninfo: + return ( + len(warninfo) == 1 and + 'backed parallel loops cannot' in warninfo[0].args[0] + ) return False else: - assert not messages + assert not warninfo return True @@ -251,11 +253,11 @@ def test_background_thread_parallelism(backend): is_run_parallel = [False] def background_thread(is_run_parallel): - with warns(None) as records: + with warnings.catch_warnings(record=True) as warninfo: Parallel(n_jobs=2)( delayed(sleep)(.1) for _ in range(4)) - print(len(records)) - is_run_parallel[0] = len(records) == 0 + print(len(warninfo)) + is_run_parallel[0] = len(warninfo) == 0 t = threading.Thread(target=background_thread, args=(is_run_parallel,)) t.start() @@ -1180,7 +1182,7 @@ def test_memmap_with_big_offset(tmpdir): def test_warning_about_timeout_not_supported_by_backend(): - with warns(None) as warninfo: + with warnings.catch_warnings(record=True) as warninfo: Parallel(timeout=1)(delayed(square)(i) for i in range(50)) assert len(warninfo) == 1 w = warninfo[0]