Skip to content

Commit

Permalink
MAINT: remove pytest.warns(None) warnings in pytest 7 (#1264)
Browse files Browse the repository at this point in the history
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
  • Loading branch information
3 people committed Sep 16, 2022
1 parent 067ed4f commit 8aca6f4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
3 changes: 2 additions & 1 deletion joblib/test/test_dask.py
@@ -1,5 +1,6 @@
from __future__ import print_function, division, absolute_import
import os
import warnings

import pytest
from random import random
Expand Down Expand Up @@ -475,7 +476,7 @@ def func_using_joblib_parallel():
# pytest.warns(UserWarning)) make the test hang. Work-around: return
# the warning record to the client and the warning check is done
# client-side.
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
Parallel(n_jobs=2, backend=backend)(
delayed(inc)(i) for i in range(10))

Expand Down
4 changes: 2 additions & 2 deletions joblib/test/test_numpy_pickle.py
Expand Up @@ -300,7 +300,7 @@ def test_cache_size_warning(tmpdir, cache_size):
a = rnd.random_sample((10, 2))

warnings.simplefilter("always")
with warns(None) as warninfo:
with warnings.catch_warnings(record=True) as warninfo:
numpy_pickle.dump(a, filename, cache_size=cache_size)
expected_nb_warnings = 1 if cache_size is not None else 0
assert len(warninfo) == expected_nb_warnings
Expand Down Expand Up @@ -385,7 +385,7 @@ def _check_pickle(filename, expected_list, mmap_mode=None):
py_version_used_for_writing, 4)
if pickle_reading_protocol >= pickle_writing_protocol:
try:
with warns(None) as warninfo:
with warnings.catch_warnings(record=True) as warninfo:
warnings.simplefilter('always')
warnings.filterwarnings(
'ignore', module='numpy',
Expand Down
30 changes: 16 additions & 14 deletions joblib/test/test_parallel.py
Expand Up @@ -10,6 +10,7 @@
import sys
import time
import mmap
import warnings
import threading
from traceback import format_exception
from math import sqrt
Expand All @@ -20,8 +21,6 @@
import warnings
import pytest

from importlib import reload

import joblib
from joblib import parallel
from joblib import dump, load
Expand All @@ -31,7 +30,7 @@
from joblib.test.common import np, with_numpy
from joblib.test.common import with_multiprocessing
from joblib.testing import (parametrize, raises, check_subprocess_call,
skipif, SkipTest, warns)
skipif, warns)

if mp is not None:
# Loky is not available if multiprocessing is not
Expand Down Expand Up @@ -181,7 +180,7 @@ def test_main_thread_renamed_no_warning(backend, monkeypatch):
monkeypatch.setattr(target=threading.current_thread(), name='name',
value='some_new_name_for_the_main_thread')

with warns(None) as warninfo:
with warnings.catch_warnings(record=True) as warninfo:
results = Parallel(n_jobs=2, backend=backend)(
delayed(square)(x) for x in range(3))
assert results == [0, 1, 4]
Expand All @@ -197,18 +196,21 @@ def test_main_thread_renamed_no_warning(backend, monkeypatch):


def _assert_warning_nested(backend, inner_n_jobs, expected):
with warnings.catch_warnings(record=True) as records:
with warnings.catch_warnings(record=True) as warninfo:
warnings.simplefilter("always")
parallel_func(backend=backend, inner_n_jobs=inner_n_jobs)

messages = [w.message for w in records]
warninfo = [w.message for w in warninfo]
if expected:
# with threading, we might see more that one records
if messages:
return 'backed parallel loops cannot' in messages[0].args[0]
# with threading, we might see more that one warninfo
if warninfo:
return (
len(warninfo) == 1 and
'backed parallel loops cannot' in warninfo[0].args[0]
)
return False
else:
assert not messages
assert not warninfo
return True


Expand Down Expand Up @@ -251,11 +253,11 @@ def test_background_thread_parallelism(backend):
is_run_parallel = [False]

def background_thread(is_run_parallel):
with warns(None) as records:
with warnings.catch_warnings(record=True) as warninfo:
Parallel(n_jobs=2)(
delayed(sleep)(.1) for _ in range(4))
print(len(records))
is_run_parallel[0] = len(records) == 0
print(len(warninfo))
is_run_parallel[0] = len(warninfo) == 0

t = threading.Thread(target=background_thread, args=(is_run_parallel,))
t.start()
Expand Down Expand Up @@ -1180,7 +1182,7 @@ def test_memmap_with_big_offset(tmpdir):


def test_warning_about_timeout_not_supported_by_backend():
with warns(None) as warninfo:
with warnings.catch_warnings(record=True) as warninfo:
Parallel(timeout=1)(delayed(square)(i) for i in range(50))
assert len(warninfo) == 1
w = warninfo[0]
Expand Down

0 comments on commit 8aca6f4

Please sign in to comment.