From 3e6f0f34ff97787897506e93cd5ddcdd17133887 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Mon, 25 Nov 2019 16:49:31 +0200 Subject: [PATCH 1/5] Cleanup unhelpful alias _AST_FLAG Also replace one direct call to `compile` with this flag with the equivalent wrapper `ast.parse`. This function can have a more precise type. --- src/_pytest/_code/source.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index d7cef683d7a..ee3f7cb148a 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -5,7 +5,6 @@ import textwrap import tokenize import warnings -from ast import PyCF_ONLY_AST as _AST_FLAG from bisect import bisect_right from types import FrameType from typing import Iterator @@ -196,7 +195,7 @@ def compile( newex.text = ex.text raise newex else: - if flag & _AST_FLAG: + if flag & ast.PyCF_ONLY_AST: return co lines = [(x + "\n") for x in self.lines] # Type ignored because linecache.cache is private. @@ -321,7 +320,7 @@ def getstatementrange_ast( # don't produce duplicate warnings when compiling source to find ast with warnings.catch_warnings(): warnings.simplefilter("ignore") - astnode = compile(content, "source", "exec", _AST_FLAG) + astnode = ast.parse(content, "source", "exec") start, end = get_statement_startend2(lineno, astnode) # we need to correct the end: From 0c247be76932e9cf066c091136074f7ac8ed3c05 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Mon, 25 Nov 2019 17:20:54 +0200 Subject: [PATCH 2/5] Add a few missing type annotations in _pytest._code These are more "dirty" than the previous batch (that's why they were left out). The trouble is that `compile` can return either a code object or an AST depending on a flag, so we need to add an overload to make the common case Union free. But it's still worthwhile. --- src/_pytest/_code/code.py | 4 +- src/_pytest/_code/source.py | 77 ++++++++++++++++++++++++++++++++----- testing/code/test_source.py | 16 ++++++++ 3 files changed, 86 insertions(+), 11 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 14428c8854d..55c9e910036 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -67,7 +67,7 @@ def __ne__(self, other): return not self == other @property - def path(self): + def path(self) -> Union[py.path.local, str]: """ return a path object pointing to source code (note that it might not point to an actually existing file). """ try: @@ -335,7 +335,7 @@ def cut( (path is None or codepath == path) and ( excludepath is None - or not hasattr(codepath, "relto") + or not isinstance(codepath, py.path.local) or not codepath.relto(excludepath) ) and (lineno is None or x.lineno == lineno) diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index ee3f7cb148a..67c74143f55 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -6,6 +6,7 @@ import tokenize import warnings from bisect import bisect_right +from types import CodeType from types import FrameType from typing import Iterator from typing import List @@ -17,6 +18,10 @@ import py from _pytest.compat import overload +from _pytest.compat import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Literal class Source: @@ -120,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source": start, end = self.getstatementrange(lineno) return self[start:end] - def getstatementrange(self, lineno: int): + def getstatementrange(self, lineno: int) -> Tuple[int, int]: """ return (start, end) tuple which spans the minimal statement region which containing the given lineno. """ @@ -158,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool: def __str__(self) -> str: return "\n".join(self.lines) + @overload def compile( self, - filename=None, - mode="exec", + filename: Optional[str] = ..., + mode: str = ..., + flag: "Literal[0]" = ..., + dont_inherit: int = ..., + _genframe: Optional[FrameType] = ..., + ) -> CodeType: + raise NotImplementedError() + + @overload # noqa: F811 + def compile( # noqa: F811 + self, + filename: Optional[str] = ..., + mode: str = ..., + flag: int = ..., + dont_inherit: int = ..., + _genframe: Optional[FrameType] = ..., + ) -> Union[CodeType, ast.AST]: + raise NotImplementedError() + + def compile( # noqa: F811 + self, + filename: Optional[str] = None, + mode: str = "exec", flag: int = 0, dont_inherit: int = 0, _genframe: Optional[FrameType] = None, - ): + ) -> Union[CodeType, ast.AST]: """ return compiled code object. if filename is None invent an artificial filename which displays the source/line position of the caller frame. @@ -196,7 +223,9 @@ def compile( raise newex else: if flag & ast.PyCF_ONLY_AST: + assert isinstance(co, ast.AST) return co + assert isinstance(co, CodeType) lines = [(x + "\n") for x in self.lines] # Type ignored because linecache.cache is private. linecache.cache[filename] = (1, None, lines, filename) # type: ignore @@ -208,7 +237,35 @@ def compile( # -def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0): +@overload +def compile_( + source: Union[str, bytes, ast.mod, ast.AST], + filename: Optional[str] = ..., + mode: str = ..., + flags: "Literal[0]" = ..., + dont_inherit: int = ..., +) -> CodeType: + raise NotImplementedError() + + +@overload # noqa: F811 +def compile_( # noqa: F811 + source: Union[str, bytes, ast.mod, ast.AST], + filename: Optional[str] = ..., + mode: str = ..., + flags: int = ..., + dont_inherit: int = ..., +) -> Union[CodeType, ast.AST]: + raise NotImplementedError() + + +def compile_( # noqa: F811 + source: Union[str, bytes, ast.mod, ast.AST], + filename: Optional[str] = None, + mode: str = "exec", + flags: int = 0, + dont_inherit: int = 0, +) -> Union[CodeType, ast.AST]: """ compile the given source to a raw code object, and maintain an internal cache which allows later retrieval of the source code for the code object @@ -216,14 +273,16 @@ def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: i """ if isinstance(source, ast.AST): # XXX should Source support having AST? - return compile(source, filename, mode, flags, dont_inherit) + assert filename is not None + co = compile(source, filename, mode, flags, dont_inherit) + assert isinstance(co, (CodeType, ast.AST)) + return co _genframe = sys._getframe(1) # the caller s = Source(source) - co = s.compile(filename, mode, flags, _genframe=_genframe) - return co + return s.compile(filename, mode, flags, _genframe=_genframe) -def getfslineno(obj): +def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]: """ Return source location (path, lineno) for the given object. If the source cannot be determined return ("", -1). diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 1390d8b0ac5..030e6067625 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -4,10 +4,13 @@ import ast import inspect import sys +from types import CodeType from typing import Any from typing import Dict from typing import Optional +import py + import _pytest._code import pytest from _pytest._code import Source @@ -147,6 +150,10 @@ def test_getrange(self) -> None: assert len(x.lines) == 2 assert str(x) == "def f(x):\n pass" + def test_getrange_step_not_supported(self) -> None: + with pytest.raises(IndexError, match=r"step"): + self.source[::2] + def test_getline(self) -> None: x = self.source[0] assert x == "def f(x):" @@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None: assert src == expected +def test_compile_ast() -> None: + # We don't necessarily want to support this. + # This test was added just for coverage. + stmt = ast.parse("def x(): pass") + co = _pytest._code.compile(stmt, filename="foo.py") + assert isinstance(co, CodeType) + + def test_findsource_fallback() -> None: from _pytest._code.source import findsource @@ -488,6 +503,7 @@ def f(x) -> None: fspath, lineno = getfslineno(f) + assert isinstance(fspath, py.path.local) assert fspath.basename == "test_source.py" assert lineno == f.__code__.co_firstlineno - 1 # see findsource From 0b603156b92b258dd971b3d20cf2d158783efd66 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Tue, 3 Dec 2019 14:25:28 +0200 Subject: [PATCH 3/5] Fix check_untyped_defs errors in test_pytester --- testing/test_pytester.py | 134 ++++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 65 deletions(-) diff --git a/testing/test_pytester.py b/testing/test_pytester.py index 0015e489a92..6c8c933d7e9 100644 --- a/testing/test_pytester.py +++ b/testing/test_pytester.py @@ -2,6 +2,7 @@ import subprocess import sys import time +from typing import List import py.path @@ -9,6 +10,7 @@ import pytest from _pytest.config import PytestPluginManager from _pytest.main import ExitCode +from _pytest.outcomes import Failed from _pytest.pytester import CwdSnapshot from _pytest.pytester import HookRecorder from _pytest.pytester import LineMatcher @@ -16,7 +18,7 @@ from _pytest.pytester import SysPathsSnapshot -def test_make_hook_recorder(testdir): +def test_make_hook_recorder(testdir) -> None: item = testdir.getitem("def test_func(): pass") recorder = testdir.make_hook_recorder(item.config.pluginmanager) assert not recorder.getfailures() @@ -36,23 +38,23 @@ class rep: failures = recorder.getfailures() assert failures == [rep] - class rep: + class rep2: excinfo = None passed = False failed = False skipped = True when = "call" - rep.passed = False - rep.skipped = True - recorder.hook.pytest_runtest_logreport(report=rep) + rep2.passed = False + rep2.skipped = True + recorder.hook.pytest_runtest_logreport(report=rep2) modcol = testdir.getmodulecol("") - rep = modcol.config.hook.pytest_make_collect_report(collector=modcol) - rep.passed = False - rep.failed = True - rep.skipped = False - recorder.hook.pytest_collectreport(report=rep) + rep3 = modcol.config.hook.pytest_make_collect_report(collector=modcol) + rep3.passed = False + rep3.failed = True + rep3.skipped = False + recorder.hook.pytest_collectreport(report=rep3) passed, skipped, failed = recorder.listoutcomes() assert not passed and skipped and failed @@ -65,17 +67,17 @@ class rep: recorder.unregister() recorder.clear() - recorder.hook.pytest_runtest_logreport(report=rep) + recorder.hook.pytest_runtest_logreport(report=rep3) pytest.raises(ValueError, recorder.getfailures) -def test_parseconfig(testdir): +def test_parseconfig(testdir) -> None: config1 = testdir.parseconfig() config2 = testdir.parseconfig() assert config2 is not config1 -def test_testdir_runs_with_plugin(testdir): +def test_testdir_runs_with_plugin(testdir) -> None: testdir.makepyfile( """ pytest_plugins = "pytester" @@ -87,7 +89,7 @@ def test_hello(testdir): result.assert_outcomes(passed=1) -def test_runresult_assertion_on_xfail(testdir): +def test_runresult_assertion_on_xfail(testdir) -> None: testdir.makepyfile( """ import pytest @@ -104,7 +106,7 @@ def test_potato(): assert result.ret == 0 -def test_runresult_assertion_on_xpassed(testdir): +def test_runresult_assertion_on_xpassed(testdir) -> None: testdir.makepyfile( """ import pytest @@ -121,7 +123,7 @@ def test_potato(): assert result.ret == 0 -def test_xpassed_with_strict_is_considered_a_failure(testdir): +def test_xpassed_with_strict_is_considered_a_failure(testdir) -> None: testdir.makepyfile( """ import pytest @@ -154,13 +156,13 @@ def pytest_xyz(arg): def pytest_xyz_noarg(): "x" - apimod.pytest_xyz = pytest_xyz - apimod.pytest_xyz_noarg = pytest_xyz_noarg + apimod.pytest_xyz = pytest_xyz # type: ignore + apimod.pytest_xyz_noarg = pytest_xyz_noarg # type: ignore return apiclass, apimod @pytest.mark.parametrize("holder", make_holder()) -def test_hookrecorder_basic(holder): +def test_hookrecorder_basic(holder) -> None: pm = PytestPluginManager() pm.add_hookspecs(holder) rec = HookRecorder(pm) @@ -168,17 +170,17 @@ def test_hookrecorder_basic(holder): call = rec.popcall("pytest_xyz") assert call.arg == 123 assert call._name == "pytest_xyz" - pytest.raises(pytest.fail.Exception, rec.popcall, "abc") + pytest.raises(Failed, rec.popcall, "abc") pm.hook.pytest_xyz_noarg() call = rec.popcall("pytest_xyz_noarg") assert call._name == "pytest_xyz_noarg" -def test_makepyfile_unicode(testdir): +def test_makepyfile_unicode(testdir) -> None: testdir.makepyfile(chr(0xFFFD)) -def test_makepyfile_utf8(testdir): +def test_makepyfile_utf8(testdir) -> None: """Ensure makepyfile accepts utf-8 bytes as input (#2738)""" utf8_contents = """ def setup_function(function): @@ -189,7 +191,7 @@ def setup_function(function): class TestInlineRunModulesCleanup: - def test_inline_run_test_module_not_cleaned_up(self, testdir): + def test_inline_run_test_module_not_cleaned_up(self, testdir) -> None: test_mod = testdir.makepyfile("def test_foo(): assert True") result = testdir.inline_run(str(test_mod)) assert result.ret == ExitCode.OK @@ -200,9 +202,9 @@ def test_inline_run_test_module_not_cleaned_up(self, testdir): def spy_factory(self): class SysModulesSnapshotSpy: - instances = [] + instances = [] # type: List[SysModulesSnapshotSpy] - def __init__(self, preserve=None): + def __init__(self, preserve=None) -> None: SysModulesSnapshotSpy.instances.append(self) self._spy_restore_count = 0 self._spy_preserve = preserve @@ -216,7 +218,7 @@ def restore(self): def test_inline_run_taking_and_restoring_a_sys_modules_snapshot( self, testdir, monkeypatch - ): + ) -> None: spy_factory = self.spy_factory() monkeypatch.setattr(pytester, "SysModulesSnapshot", spy_factory) testdir.syspathinsert() @@ -237,7 +239,7 @@ def test_foo(): import import2""" def test_inline_run_sys_modules_snapshot_restore_preserving_modules( self, testdir, monkeypatch - ): + ) -> None: spy_factory = self.spy_factory() monkeypatch.setattr(pytester, "SysModulesSnapshot", spy_factory) test_mod = testdir.makepyfile("def test_foo(): pass") @@ -248,7 +250,7 @@ def test_inline_run_sys_modules_snapshot_restore_preserving_modules( assert spy._spy_preserve("zope.interface") assert spy._spy_preserve("zopelicious") - def test_external_test_module_imports_not_cleaned_up(self, testdir): + def test_external_test_module_imports_not_cleaned_up(self, testdir) -> None: testdir.syspathinsert() testdir.makepyfile(imported="data = 'you son of a silly person'") import imported @@ -263,7 +265,7 @@ def test_foo(): assert imported.data == 42 -def test_assert_outcomes_after_pytest_error(testdir): +def test_assert_outcomes_after_pytest_error(testdir) -> None: testdir.makepyfile("def test_foo(): assert True") result = testdir.runpytest("--unexpected-argument") @@ -271,7 +273,7 @@ def test_assert_outcomes_after_pytest_error(testdir): result.assert_outcomes(passed=0) -def test_cwd_snapshot(tmpdir): +def test_cwd_snapshot(tmpdir) -> None: foo = tmpdir.ensure("foo", dir=1) bar = tmpdir.ensure("bar", dir=1) foo.chdir() @@ -285,16 +287,16 @@ def test_cwd_snapshot(tmpdir): class TestSysModulesSnapshot: key = "my-test-module" - def test_remove_added(self): + def test_remove_added(self) -> None: original = dict(sys.modules) assert self.key not in sys.modules snapshot = SysModulesSnapshot() - sys.modules[self.key] = "something" + sys.modules[self.key] = "something" # type: ignore assert self.key in sys.modules snapshot.restore() assert sys.modules == original - def test_add_removed(self, monkeypatch): + def test_add_removed(self, monkeypatch) -> None: assert self.key not in sys.modules monkeypatch.setitem(sys.modules, self.key, "something") assert self.key in sys.modules @@ -305,17 +307,17 @@ def test_add_removed(self, monkeypatch): snapshot.restore() assert sys.modules == original - def test_restore_reloaded(self, monkeypatch): + def test_restore_reloaded(self, monkeypatch) -> None: assert self.key not in sys.modules monkeypatch.setitem(sys.modules, self.key, "something") assert self.key in sys.modules original = dict(sys.modules) snapshot = SysModulesSnapshot() - sys.modules[self.key] = "something else" + sys.modules[self.key] = "something else" # type: ignore snapshot.restore() assert sys.modules == original - def test_preserve_modules(self, monkeypatch): + def test_preserve_modules(self, monkeypatch) -> None: key = [self.key + str(i) for i in range(3)] assert not any(k in sys.modules for k in key) for i, k in enumerate(key): @@ -326,17 +328,17 @@ def preserve(name): return name in (key[0], key[1], "some-other-key") snapshot = SysModulesSnapshot(preserve=preserve) - sys.modules[key[0]] = original[key[0]] = "something else0" - sys.modules[key[1]] = original[key[1]] = "something else1" - sys.modules[key[2]] = "something else2" + sys.modules[key[0]] = original[key[0]] = "something else0" # type: ignore + sys.modules[key[1]] = original[key[1]] = "something else1" # type: ignore + sys.modules[key[2]] = "something else2" # type: ignore snapshot.restore() assert sys.modules == original - def test_preserve_container(self, monkeypatch): + def test_preserve_container(self, monkeypatch) -> None: original = dict(sys.modules) assert self.key not in original replacement = dict(sys.modules) - replacement[self.key] = "life of brian" + replacement[self.key] = "life of brian" # type: ignore snapshot = SysModulesSnapshot() monkeypatch.setattr(sys, "modules", replacement) snapshot.restore() @@ -349,10 +351,10 @@ class TestSysPathsSnapshot: other_path = {"path": "meta_path", "meta_path": "path"} @staticmethod - def path(n): + def path(n: int) -> str: return "my-dirty-little-secret-" + str(n) - def test_restore(self, monkeypatch, path_type): + def test_restore(self, monkeypatch, path_type) -> None: other_path_type = self.other_path[path_type] for i in range(10): assert self.path(i) not in getattr(sys, path_type) @@ -375,12 +377,12 @@ def test_restore(self, monkeypatch, path_type): assert getattr(sys, path_type) == original assert getattr(sys, other_path_type) == original_other - def test_preserve_container(self, monkeypatch, path_type): + def test_preserve_container(self, monkeypatch, path_type) -> None: other_path_type = self.other_path[path_type] original_data = list(getattr(sys, path_type)) original_other = getattr(sys, other_path_type) original_other_data = list(original_other) - new = [] + new = [] # type: List[object] snapshot = SysPathsSnapshot() monkeypatch.setattr(sys, path_type, new) snapshot.restore() @@ -390,7 +392,7 @@ def test_preserve_container(self, monkeypatch, path_type): assert getattr(sys, other_path_type) == original_other_data -def test_testdir_subprocess(testdir): +def test_testdir_subprocess(testdir) -> None: testfile = testdir.makepyfile("def test_one(): pass") assert testdir.runpytest_subprocess(testfile).ret == 0 @@ -416,17 +418,17 @@ def test_one(): assert result.ret == 0 -def test_unicode_args(testdir): +def test_unicode_args(testdir) -> None: result = testdir.runpytest("-k", "💩") assert result.ret == ExitCode.NO_TESTS_COLLECTED -def test_testdir_run_no_timeout(testdir): +def test_testdir_run_no_timeout(testdir) -> None: testfile = testdir.makepyfile("def test_no_timeout(): pass") assert testdir.runpytest_subprocess(testfile).ret == ExitCode.OK -def test_testdir_run_with_timeout(testdir): +def test_testdir_run_with_timeout(testdir) -> None: testfile = testdir.makepyfile("def test_no_timeout(): pass") timeout = 120 @@ -440,7 +442,7 @@ def test_testdir_run_with_timeout(testdir): assert duration < timeout -def test_testdir_run_timeout_expires(testdir): +def test_testdir_run_timeout_expires(testdir) -> None: testfile = testdir.makepyfile( """ import time @@ -452,7 +454,7 @@ def test_timeout(): testdir.runpytest_subprocess(testfile, timeout=1) -def test_linematcher_with_nonlist(): +def test_linematcher_with_nonlist() -> None: """Test LineMatcher with regard to passing in a set (accidentally).""" lm = LineMatcher([]) @@ -467,10 +469,11 @@ def test_linematcher_with_nonlist(): assert lm._getlines(set()) == set() -def test_linematcher_match_failure(): +def test_linematcher_match_failure() -> None: lm = LineMatcher(["foo", "foo", "bar"]) - with pytest.raises(pytest.fail.Exception) as e: + with pytest.raises(Failed) as e: lm.fnmatch_lines(["foo", "f*", "baz"]) + assert e.value.msg is not None assert e.value.msg.splitlines() == [ "exact match: 'foo'", "fnmatch: 'f*'", @@ -481,8 +484,9 @@ def test_linematcher_match_failure(): ] lm = LineMatcher(["foo", "foo", "bar"]) - with pytest.raises(pytest.fail.Exception) as e: + with pytest.raises(Failed) as e: lm.re_match_lines(["foo", "^f.*", "baz"]) + assert e.value.msg is not None assert e.value.msg.splitlines() == [ "exact match: 'foo'", "re.match: '^f.*'", @@ -494,7 +498,7 @@ def test_linematcher_match_failure(): @pytest.mark.parametrize("function", ["no_fnmatch_line", "no_re_match_line"]) -def test_no_matching(function): +def test_no_matching(function) -> None: if function == "no_fnmatch_line": good_pattern = "*.py OK*" bad_pattern = "*X.py OK*" @@ -515,7 +519,7 @@ def test_no_matching(function): # check the function twice to ensure we don't accumulate the internal buffer for i in range(2): - with pytest.raises(pytest.fail.Exception) as e: + with pytest.raises(Failed) as e: func = getattr(lm, function) func(good_pattern) obtained = str(e.value).splitlines() @@ -542,15 +546,15 @@ def test_no_matching(function): func(bad_pattern) # bad pattern does not match any line: passes -def test_no_matching_after_match(): +def test_no_matching_after_match() -> None: lm = LineMatcher(["1", "2", "3"]) lm.fnmatch_lines(["1", "3"]) - with pytest.raises(pytest.fail.Exception) as e: + with pytest.raises(Failed) as e: lm.no_fnmatch_line("*") assert str(e.value).splitlines() == ["fnmatch: '*'", " with: '1'"] -def test_pytester_addopts(request, monkeypatch): +def test_pytester_addopts(request, monkeypatch) -> None: monkeypatch.setenv("PYTEST_ADDOPTS", "--orig-unused") testdir = request.getfixturevalue("testdir") @@ -563,7 +567,7 @@ def test_pytester_addopts(request, monkeypatch): assert os.environ["PYTEST_ADDOPTS"] == "--orig-unused" -def test_run_stdin(testdir): +def test_run_stdin(testdir) -> None: with pytest.raises(testdir.TimeoutExpired): testdir.run( sys.executable, @@ -593,7 +597,7 @@ def test_run_stdin(testdir): assert result.ret == 0 -def test_popen_stdin_pipe(testdir): +def test_popen_stdin_pipe(testdir) -> None: proc = testdir.popen( [sys.executable, "-c", "import sys; print(sys.stdin.read())"], stdout=subprocess.PIPE, @@ -607,7 +611,7 @@ def test_popen_stdin_pipe(testdir): assert proc.returncode == 0 -def test_popen_stdin_bytes(testdir): +def test_popen_stdin_bytes(testdir) -> None: proc = testdir.popen( [sys.executable, "-c", "import sys; print(sys.stdin.read())"], stdout=subprocess.PIPE, @@ -620,7 +624,7 @@ def test_popen_stdin_bytes(testdir): assert proc.returncode == 0 -def test_popen_default_stdin_stderr_and_stdin_None(testdir): +def test_popen_default_stdin_stderr_and_stdin_None(testdir) -> None: # stdout, stderr default to pipes, # stdin can be None to not close the pipe, avoiding # "ValueError: flush of closed file" with `communicate()`. @@ -639,7 +643,7 @@ def test_popen_default_stdin_stderr_and_stdin_None(testdir): assert proc.returncode == 0 -def test_spawn_uses_tmphome(testdir): +def test_spawn_uses_tmphome(testdir) -> None: import os tmphome = str(testdir.tmpdir) @@ -665,7 +669,7 @@ def test(): assert child.wait() == 0, out.decode("utf8") -def test_run_result_repr(): +def test_run_result_repr() -> None: outlines = ["some", "normal", "output"] errlines = ["some", "nasty", "errors", "happened"] From 3d2680b31b29119f8df08e9757403c9774f158e0 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Tue, 3 Dec 2019 14:34:41 +0200 Subject: [PATCH 4/5] Fix type of pytest.warns, and fix check_untyped_defs in test_recwarn The expected_warning is optional. --- src/_pytest/recwarn.py | 6 +- testing/test_recwarn.py | 118 +++++++++++++++++++++------------------- 2 files changed, 64 insertions(+), 60 deletions(-) diff --git a/src/_pytest/recwarn.py b/src/_pytest/recwarn.py index 956a9078314..c57c94b1cb1 100644 --- a/src/_pytest/recwarn.py +++ b/src/_pytest/recwarn.py @@ -57,7 +57,7 @@ def deprecated_call(func=None, *args, **kwargs): @overload def warns( - expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]], *, match: "Optional[Union[str, Pattern]]" = ... ) -> "WarningsChecker": @@ -66,7 +66,7 @@ def warns( @overload # noqa: F811 def warns( # noqa: F811 - expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]], func: Callable, *args: Any, match: Optional[Union[str, "Pattern"]] = ..., @@ -76,7 +76,7 @@ def warns( # noqa: F811 def warns( # noqa: F811 - expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]], *args: Any, match: Optional[Union[str, "Pattern"]] = None, **kwargs: Any diff --git a/testing/test_recwarn.py b/testing/test_recwarn.py index 208dc5b449d..bbcefaddf7d 100644 --- a/testing/test_recwarn.py +++ b/testing/test_recwarn.py @@ -1,17 +1,19 @@ import re import warnings +from typing import Optional import pytest +from _pytest.outcomes import Failed from _pytest.recwarn import WarningsRecorder -def test_recwarn_stacklevel(recwarn): +def test_recwarn_stacklevel(recwarn: WarningsRecorder) -> None: warnings.warn("hello") warn = recwarn.pop() assert warn.filename == __file__ -def test_recwarn_functional(testdir): +def test_recwarn_functional(testdir) -> None: testdir.makepyfile( """ import warnings @@ -26,7 +28,7 @@ def test_method(recwarn): class TestWarningsRecorderChecker: - def test_recording(self): + def test_recording(self) -> None: rec = WarningsRecorder() with rec: assert not rec.list @@ -42,23 +44,23 @@ def test_recording(self): assert values is rec.list pytest.raises(AssertionError, rec.pop) - def test_warn_stacklevel(self): + def test_warn_stacklevel(self) -> None: """#4243""" rec = WarningsRecorder() with rec: warnings.warn("test", DeprecationWarning, 2) - def test_typechecking(self): + def test_typechecking(self) -> None: from _pytest.recwarn import WarningsChecker with pytest.raises(TypeError): - WarningsChecker(5) + WarningsChecker(5) # type: ignore with pytest.raises(TypeError): - WarningsChecker(("hi", RuntimeWarning)) + WarningsChecker(("hi", RuntimeWarning)) # type: ignore with pytest.raises(TypeError): - WarningsChecker([DeprecationWarning, RuntimeWarning]) + WarningsChecker([DeprecationWarning, RuntimeWarning]) # type: ignore - def test_invalid_enter_exit(self): + def test_invalid_enter_exit(self) -> None: # wrap this test in WarningsRecorder to ensure warning state gets reset with WarningsRecorder(): with pytest.raises(RuntimeError): @@ -75,50 +77,52 @@ def test_invalid_enter_exit(self): class TestDeprecatedCall: """test pytest.deprecated_call()""" - def dep(self, i, j=None): + def dep(self, i: int, j: Optional[int] = None) -> int: if i == 0: warnings.warn("is deprecated", DeprecationWarning, stacklevel=1) return 42 - def dep_explicit(self, i): + def dep_explicit(self, i: int) -> None: if i == 0: warnings.warn_explicit( "dep_explicit", category=DeprecationWarning, filename="hello", lineno=3 ) - def test_deprecated_call_raises(self): - with pytest.raises(pytest.fail.Exception, match="No warnings of type"): + def test_deprecated_call_raises(self) -> None: + with pytest.raises(Failed, match="No warnings of type"): pytest.deprecated_call(self.dep, 3, 5) - def test_deprecated_call(self): + def test_deprecated_call(self) -> None: pytest.deprecated_call(self.dep, 0, 5) - def test_deprecated_call_ret(self): + def test_deprecated_call_ret(self) -> None: ret = pytest.deprecated_call(self.dep, 0) assert ret == 42 - def test_deprecated_call_preserves(self): - onceregistry = warnings.onceregistry.copy() - filters = warnings.filters[:] + def test_deprecated_call_preserves(self) -> None: + # Type ignored because `onceregistry` and `filters` are not + # documented API. + onceregistry = warnings.onceregistry.copy() # type: ignore + filters = warnings.filters[:] # type: ignore warn = warnings.warn warn_explicit = warnings.warn_explicit self.test_deprecated_call_raises() self.test_deprecated_call() - assert onceregistry == warnings.onceregistry - assert filters == warnings.filters + assert onceregistry == warnings.onceregistry # type: ignore + assert filters == warnings.filters # type: ignore assert warn is warnings.warn assert warn_explicit is warnings.warn_explicit - def test_deprecated_explicit_call_raises(self): - with pytest.raises(pytest.fail.Exception): + def test_deprecated_explicit_call_raises(self) -> None: + with pytest.raises(Failed): pytest.deprecated_call(self.dep_explicit, 3) - def test_deprecated_explicit_call(self): + def test_deprecated_explicit_call(self) -> None: pytest.deprecated_call(self.dep_explicit, 0) pytest.deprecated_call(self.dep_explicit, 0) @pytest.mark.parametrize("mode", ["context_manager", "call"]) - def test_deprecated_call_no_warning(self, mode): + def test_deprecated_call_no_warning(self, mode) -> None: """Ensure deprecated_call() raises the expected failure when its block/function does not raise a deprecation warning. """ @@ -127,7 +131,7 @@ def f(): pass msg = "No warnings of type (.*DeprecationWarning.*, .*PendingDeprecationWarning.*)" - with pytest.raises(pytest.fail.Exception, match=msg): + with pytest.raises(Failed, match=msg): if mode == "call": pytest.deprecated_call(f) else: @@ -140,7 +144,7 @@ def f(): @pytest.mark.parametrize("mode", ["context_manager", "call"]) @pytest.mark.parametrize("call_f_first", [True, False]) @pytest.mark.filterwarnings("ignore") - def test_deprecated_call_modes(self, warning_type, mode, call_f_first): + def test_deprecated_call_modes(self, warning_type, mode, call_f_first) -> None: """Ensure deprecated_call() captures a deprecation warning as expected inside its block/function. """ @@ -159,7 +163,7 @@ def f(): assert f() == 10 @pytest.mark.parametrize("mode", ["context_manager", "call"]) - def test_deprecated_call_exception_is_raised(self, mode): + def test_deprecated_call_exception_is_raised(self, mode) -> None: """If the block of the code being tested by deprecated_call() raises an exception, it must raise the exception undisturbed. """ @@ -174,7 +178,7 @@ def f(): with pytest.deprecated_call(): f() - def test_deprecated_call_specificity(self): + def test_deprecated_call_specificity(self) -> None: other_warnings = [ Warning, UserWarning, @@ -189,40 +193,40 @@ def test_deprecated_call_specificity(self): def f(): warnings.warn(warning("hi")) - with pytest.raises(pytest.fail.Exception): + with pytest.raises(Failed): pytest.deprecated_call(f) - with pytest.raises(pytest.fail.Exception): + with pytest.raises(Failed): with pytest.deprecated_call(): f() - def test_deprecated_call_supports_match(self): + def test_deprecated_call_supports_match(self) -> None: with pytest.deprecated_call(match=r"must be \d+$"): warnings.warn("value must be 42", DeprecationWarning) - with pytest.raises(pytest.fail.Exception): + with pytest.raises(Failed): with pytest.deprecated_call(match=r"must be \d+$"): warnings.warn("this is not here", DeprecationWarning) class TestWarns: - def test_check_callable(self): + def test_check_callable(self) -> None: source = "warnings.warn('w1', RuntimeWarning)" with pytest.raises(TypeError, match=r".* must be callable"): - pytest.warns(RuntimeWarning, source) + pytest.warns(RuntimeWarning, source) # type: ignore - def test_several_messages(self): + def test_several_messages(self) -> None: # different messages, b/c Python suppresses multiple identical warnings pytest.warns(RuntimeWarning, lambda: warnings.warn("w1", RuntimeWarning)) - with pytest.raises(pytest.fail.Exception): + with pytest.raises(Failed): pytest.warns(UserWarning, lambda: warnings.warn("w2", RuntimeWarning)) pytest.warns(RuntimeWarning, lambda: warnings.warn("w3", RuntimeWarning)) - def test_function(self): + def test_function(self) -> None: pytest.warns( SyntaxWarning, lambda msg: warnings.warn(msg, SyntaxWarning), "syntax" ) - def test_warning_tuple(self): + def test_warning_tuple(self) -> None: pytest.warns( (RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w1", RuntimeWarning) ) @@ -230,21 +234,21 @@ def test_warning_tuple(self): (RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w2", SyntaxWarning) ) pytest.raises( - pytest.fail.Exception, + Failed, lambda: pytest.warns( (RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w3", UserWarning), ), ) - def test_as_contextmanager(self): + def test_as_contextmanager(self) -> None: with pytest.warns(RuntimeWarning): warnings.warn("runtime", RuntimeWarning) with pytest.warns(UserWarning): warnings.warn("user", UserWarning) - with pytest.raises(pytest.fail.Exception) as excinfo: + with pytest.raises(Failed) as excinfo: with pytest.warns(RuntimeWarning): warnings.warn("user", UserWarning) excinfo.match( @@ -252,7 +256,7 @@ def test_as_contextmanager(self): r"The list of emitted warnings is: \[UserWarning\('user',?\)\]." ) - with pytest.raises(pytest.fail.Exception) as excinfo: + with pytest.raises(Failed) as excinfo: with pytest.warns(UserWarning): warnings.warn("runtime", RuntimeWarning) excinfo.match( @@ -260,7 +264,7 @@ def test_as_contextmanager(self): r"The list of emitted warnings is: \[RuntimeWarning\('runtime',?\)\]." ) - with pytest.raises(pytest.fail.Exception) as excinfo: + with pytest.raises(Failed) as excinfo: with pytest.warns(UserWarning): pass excinfo.match( @@ -269,7 +273,7 @@ def test_as_contextmanager(self): ) warning_classes = (UserWarning, FutureWarning) - with pytest.raises(pytest.fail.Exception) as excinfo: + with pytest.raises(Failed) as excinfo: with pytest.warns(warning_classes) as warninfo: warnings.warn("runtime", RuntimeWarning) warnings.warn("import", ImportWarning) @@ -286,14 +290,14 @@ def test_as_contextmanager(self): ) ) - def test_record(self): + def test_record(self) -> None: with pytest.warns(UserWarning) as record: warnings.warn("user", UserWarning) assert len(record) == 1 assert str(record[0].message) == "user" - def test_record_only(self): + def test_record_only(self) -> None: with pytest.warns(None) as record: warnings.warn("user", UserWarning) warnings.warn("runtime", RuntimeWarning) @@ -302,7 +306,7 @@ def test_record_only(self): assert str(record[0].message) == "user" assert str(record[1].message) == "runtime" - def test_record_by_subclass(self): + def test_record_by_subclass(self) -> None: with pytest.warns(Warning) as record: warnings.warn("user", UserWarning) warnings.warn("runtime", RuntimeWarning) @@ -325,7 +329,7 @@ class MyRuntimeWarning(RuntimeWarning): assert str(record[0].message) == "user" assert str(record[1].message) == "runtime" - def test_double_test(self, testdir): + def test_double_test(self, testdir) -> None: """If a test is run again, the warning should still be raised""" testdir.makepyfile( """ @@ -341,32 +345,32 @@ def test(run): result = testdir.runpytest() result.stdout.fnmatch_lines(["*2 passed in*"]) - def test_match_regex(self): + def test_match_regex(self) -> None: with pytest.warns(UserWarning, match=r"must be \d+$"): warnings.warn("value must be 42", UserWarning) - with pytest.raises(pytest.fail.Exception): + with pytest.raises(Failed): with pytest.warns(UserWarning, match=r"must be \d+$"): warnings.warn("this is not here", UserWarning) - with pytest.raises(pytest.fail.Exception): + with pytest.raises(Failed): with pytest.warns(FutureWarning, match=r"must be \d+$"): warnings.warn("value must be 42", UserWarning) - def test_one_from_multiple_warns(self): + def test_one_from_multiple_warns(self) -> None: with pytest.warns(UserWarning, match=r"aaa"): warnings.warn("cccccccccc", UserWarning) warnings.warn("bbbbbbbbbb", UserWarning) warnings.warn("aaaaaaaaaa", UserWarning) - def test_none_of_multiple_warns(self): - with pytest.raises(pytest.fail.Exception): + def test_none_of_multiple_warns(self) -> None: + with pytest.raises(Failed): with pytest.warns(UserWarning, match=r"aaa"): warnings.warn("bbbbbbbbbb", UserWarning) warnings.warn("cccccccccc", UserWarning) @pytest.mark.filterwarnings("ignore") - def test_can_capture_previously_warned(self): + def test_can_capture_previously_warned(self) -> None: def f(): warnings.warn(UserWarning("ohai")) return 10 @@ -375,8 +379,8 @@ def f(): assert pytest.warns(UserWarning, f) == 10 assert pytest.warns(UserWarning, f) == 10 - def test_warns_context_manager_with_kwargs(self): + def test_warns_context_manager_with_kwargs(self) -> None: with pytest.raises(TypeError) as excinfo: - with pytest.warns(UserWarning, foo="bar"): + with pytest.warns(UserWarning, foo="bar"): # type: ignore pass assert "Unexpected keyword arguments" in str(excinfo.value) From 3392be37e1966a2bda933ea5768b864ffdbb1490 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Tue, 3 Dec 2019 14:48:22 +0200 Subject: [PATCH 5/5] Fix check_untyped_defs in test_runner --- src/_pytest/reports.py | 9 ++- testing/test_runner.py | 169 +++++++++++++++++++++++------------------ 2 files changed, 103 insertions(+), 75 deletions(-) diff --git a/src/_pytest/reports.py b/src/_pytest/reports.py index 215c1c3e7c9..79e106a65ad 100644 --- a/src/_pytest/reports.py +++ b/src/_pytest/reports.py @@ -1,5 +1,6 @@ from io import StringIO from pprint import pprint +from typing import Any from typing import List from typing import Optional from typing import Tuple @@ -17,6 +18,7 @@ from _pytest._code.code import ReprLocals from _pytest._code.code import ReprTraceback from _pytest._code.code import TerminalRepr +from _pytest.compat import TYPE_CHECKING from _pytest.nodes import Node from _pytest.outcomes import skip from _pytest.pathlib import Path @@ -41,9 +43,14 @@ class BaseReport: sections = [] # type: List[Tuple[str, str]] nodeid = None # type: str - def __init__(self, **kw): + def __init__(self, **kw: Any) -> None: self.__dict__.update(kw) + if TYPE_CHECKING: + # Can have arbitrary fields given to __init__(). + def __getattr__(self, key: str) -> Any: + raise NotImplementedError() + def toterminal(self, out) -> None: if hasattr(self, "node"): out.line(getslaveinfoline(self.node)) # type: ignore diff --git a/testing/test_runner.py b/testing/test_runner.py index df0b7b0cf6c..ecb60d4bee2 100644 --- a/testing/test_runner.py +++ b/testing/test_runner.py @@ -2,6 +2,9 @@ import os import sys import types +from typing import Dict +from typing import List +from typing import Tuple import py @@ -11,11 +14,17 @@ from _pytest import outcomes from _pytest import reports from _pytest import runner +from _pytest.outcomes import Exit +from _pytest.outcomes import Failed from _pytest.outcomes import OutcomeException +from _pytest.outcomes import Skipped + +if False: # TYPE_CHECKING + from typing import Type class TestSetupState: - def test_setup(self, testdir): + def test_setup(self, testdir) -> None: ss = runner.SetupState() item = testdir.getitem("def test_func(): pass") values = [1] @@ -25,14 +34,14 @@ def test_setup(self, testdir): ss._pop_and_teardown() assert not values - def test_teardown_exact_stack_empty(self, testdir): + def test_teardown_exact_stack_empty(self, testdir) -> None: item = testdir.getitem("def test_func(): pass") ss = runner.SetupState() ss.teardown_exact(item, None) ss.teardown_exact(item, None) ss.teardown_exact(item, None) - def test_setup_fails_and_failure_is_cached(self, testdir): + def test_setup_fails_and_failure_is_cached(self, testdir) -> None: item = testdir.getitem( """ def setup_module(mod): @@ -44,7 +53,7 @@ def test_func(): pass pytest.raises(ValueError, lambda: ss.prepare(item)) pytest.raises(ValueError, lambda: ss.prepare(item)) - def test_teardown_multiple_one_fails(self, testdir): + def test_teardown_multiple_one_fails(self, testdir) -> None: r = [] def fin1(): @@ -66,7 +75,7 @@ def fin3(): assert err.value.args == ("oops",) assert r == ["fin3", "fin1"] - def test_teardown_multiple_fail(self, testdir): + def test_teardown_multiple_fail(self, testdir) -> None: # Ensure the first exception is the one which is re-raised. # Ideally both would be reported however. def fin1(): @@ -83,7 +92,7 @@ def fin2(): ss._callfinalizers(item) assert err.value.args == ("oops2",) - def test_teardown_multiple_scopes_one_fails(self, testdir): + def test_teardown_multiple_scopes_one_fails(self, testdir) -> None: module_teardown = [] def fin_func(): @@ -103,7 +112,7 @@ def fin_module(): class BaseFunctionalTests: - def test_passfunction(self, testdir): + def test_passfunction(self, testdir) -> None: reports = testdir.runitem( """ def test_func(): @@ -116,7 +125,7 @@ def test_func(): assert rep.outcome == "passed" assert not rep.longrepr - def test_failfunction(self, testdir): + def test_failfunction(self, testdir) -> None: reports = testdir.runitem( """ def test_func(): @@ -131,7 +140,7 @@ def test_func(): assert rep.outcome == "failed" # assert isinstance(rep.longrepr, ReprExceptionInfo) - def test_skipfunction(self, testdir): + def test_skipfunction(self, testdir) -> None: reports = testdir.runitem( """ import pytest @@ -151,7 +160,7 @@ def test_func(): # assert rep.skipped.location.path # assert not rep.skipped.failurerepr - def test_skip_in_setup_function(self, testdir): + def test_skip_in_setup_function(self, testdir) -> None: reports = testdir.runitem( """ import pytest @@ -172,7 +181,7 @@ def test_func(): assert len(reports) == 2 assert reports[1].passed # teardown - def test_failure_in_setup_function(self, testdir): + def test_failure_in_setup_function(self, testdir) -> None: reports = testdir.runitem( """ import pytest @@ -189,7 +198,7 @@ def test_func(): assert rep.when == "setup" assert len(reports) == 2 - def test_failure_in_teardown_function(self, testdir): + def test_failure_in_teardown_function(self, testdir) -> None: reports = testdir.runitem( """ import pytest @@ -209,7 +218,7 @@ def test_func(): # assert rep.longrepr.reprcrash.lineno == 3 # assert rep.longrepr.reprtraceback.reprentries - def test_custom_failure_repr(self, testdir): + def test_custom_failure_repr(self, testdir) -> None: testdir.makepyfile( conftest=""" import pytest @@ -234,7 +243,7 @@ def test_func(): # assert rep.failed.where.path.basename == "test_func.py" # assert rep.failed.failurerepr == "hello" - def test_teardown_final_returncode(self, testdir): + def test_teardown_final_returncode(self, testdir) -> None: rec = testdir.inline_runsource( """ def test_func(): @@ -245,7 +254,7 @@ def teardown_function(func): ) assert rec.ret == 1 - def test_logstart_logfinish_hooks(self, testdir): + def test_logstart_logfinish_hooks(self, testdir) -> None: rec = testdir.inline_runsource( """ import pytest @@ -262,7 +271,7 @@ def test_func(): assert rep.nodeid == "test_logstart_logfinish_hooks.py::test_func" assert rep.location == ("test_logstart_logfinish_hooks.py", 1, "test_func") - def test_exact_teardown_issue90(self, testdir): + def test_exact_teardown_issue90(self, testdir) -> None: rec = testdir.inline_runsource( """ import pytest @@ -302,7 +311,7 @@ def teardown_function(func): assert reps[5].nodeid.endswith("test_func") assert reps[5].failed - def test_exact_teardown_issue1206(self, testdir): + def test_exact_teardown_issue1206(self, testdir) -> None: """issue shadowing error with wrong number of arguments on teardown_method.""" rec = testdir.inline_runsource( """ @@ -338,7 +347,7 @@ def test_method(self): "TypeError: teardown_method() takes exactly 4 arguments (2 given)", ) - def test_failure_in_setup_function_ignores_custom_repr(self, testdir): + def test_failure_in_setup_function_ignores_custom_repr(self, testdir) -> None: testdir.makepyfile( conftest=""" import pytest @@ -366,7 +375,7 @@ def test_func(): # assert rep.outcome.where.path.basename == "test_func.py" # assert instanace(rep.failed.failurerepr, PythonFailureRepr) - def test_systemexit_does_not_bail_out(self, testdir): + def test_systemexit_does_not_bail_out(self, testdir) -> None: try: reports = testdir.runitem( """ @@ -380,7 +389,7 @@ def test_func(): assert rep.failed assert rep.when == "call" - def test_exit_propagates(self, testdir): + def test_exit_propagates(self, testdir) -> None: try: testdir.runitem( """ @@ -389,7 +398,7 @@ def test_func(): raise pytest.exit.Exception() """ ) - except pytest.exit.Exception: + except Exit: pass else: pytest.fail("did not raise") @@ -402,7 +411,7 @@ def f(item): return f - def test_keyboardinterrupt_propagates(self, testdir): + def test_keyboardinterrupt_propagates(self, testdir) -> None: try: testdir.runitem( """ @@ -424,7 +433,7 @@ def getrunner(self): boxed = pytest.importorskip("xdist.boxed") return boxed.forked_run_report - def test_suicide(self, testdir): + def test_suicide(self, testdir) -> None: reports = testdir.runitem( """ def test_func(): @@ -438,7 +447,7 @@ def test_func(): class TestSessionReports: - def test_collect_result(self, testdir): + def test_collect_result(self, testdir) -> None: col = testdir.getmodulecol( """ def test_func1(): @@ -461,20 +470,24 @@ class TestClass(object): assert res[1].name == "TestClass" -reporttypes = [reports.BaseReport, reports.TestReport, reports.CollectReport] +reporttypes = [ + reports.BaseReport, + reports.TestReport, + reports.CollectReport, +] # type: List[Type[reports.BaseReport]] @pytest.mark.parametrize( "reporttype", reporttypes, ids=[x.__name__ for x in reporttypes] ) -def test_report_extra_parameters(reporttype): +def test_report_extra_parameters(reporttype: "Type[reports.BaseReport]") -> None: args = list(inspect.signature(reporttype.__init__).parameters.keys())[1:] - basekw = dict.fromkeys(args, []) + basekw = dict.fromkeys(args, []) # type: Dict[str, List[object]] report = reporttype(newthing=1, **basekw) assert report.newthing == 1 -def test_callinfo(): +def test_callinfo() -> None: ci = runner.CallInfo.from_call(lambda: 0, "123") assert ci.when == "123" assert ci.result == 0 @@ -503,7 +516,7 @@ def raise_assertion(): @pytest.mark.xfail -def test_runtest_in_module_ordering(testdir): +def test_runtest_in_module_ordering(testdir) -> None: p1 = testdir.makepyfile( """ import pytest @@ -534,12 +547,12 @@ def pytest_runtest_teardown(item): result.stdout.fnmatch_lines(["*2 passed*"]) -def test_outcomeexception_exceptionattributes(): +def test_outcomeexception_exceptionattributes() -> None: outcome = outcomes.OutcomeException("test") assert outcome.args[0] == outcome.msg -def test_outcomeexception_passes_except_Exception(): +def test_outcomeexception_passes_except_Exception() -> None: with pytest.raises(outcomes.OutcomeException): try: raise outcomes.OutcomeException("test") @@ -547,20 +560,22 @@ def test_outcomeexception_passes_except_Exception(): pass -def test_pytest_exit(): - with pytest.raises(pytest.exit.Exception) as excinfo: +def test_pytest_exit() -> None: + assert Exit == pytest.exit.Exception # type: ignore + with pytest.raises(Exit) as excinfo: pytest.exit("hello") - assert excinfo.errisinstance(pytest.exit.Exception) + assert excinfo.errisinstance(Exit) -def test_pytest_fail(): - with pytest.raises(pytest.fail.Exception) as excinfo: +def test_pytest_fail() -> None: + assert Failed == pytest.fail.Exception # type: ignore + with pytest.raises(Failed) as excinfo: pytest.fail("hello") s = excinfo.exconly(tryshort=True) assert s.startswith("Failed") -def test_pytest_exit_msg(testdir): +def test_pytest_exit_msg(testdir) -> None: testdir.makeconftest( """ import pytest @@ -583,7 +598,7 @@ def _strip_resource_warnings(lines): ] -def test_pytest_exit_returncode(testdir): +def test_pytest_exit_returncode(testdir) -> None: testdir.makepyfile( """\ import pytest @@ -614,7 +629,7 @@ def pytest_sessionstart(): assert result.ret == 98 -def test_pytest_fail_notrace_runtest(testdir): +def test_pytest_fail_notrace_runtest(testdir) -> None: """Test pytest.fail(..., pytrace=False) does not show tracebacks during test run.""" testdir.makepyfile( """ @@ -630,7 +645,7 @@ def teardown_function(function): result.stdout.no_fnmatch_line("*def teardown_function*") -def test_pytest_fail_notrace_collection(testdir): +def test_pytest_fail_notrace_collection(testdir) -> None: """Test pytest.fail(..., pytrace=False) does not show tracebacks during collection.""" testdir.makepyfile( """ @@ -645,7 +660,7 @@ def some_internal_function(): result.stdout.no_fnmatch_line("*def some_internal_function()*") -def test_pytest_fail_notrace_non_ascii(testdir): +def test_pytest_fail_notrace_non_ascii(testdir) -> None: """Fix pytest.fail with pytrace=False with non-ascii characters (#1178). This tests with native and unicode strings containing non-ascii chars. @@ -663,7 +678,7 @@ def test_hello(): result.stdout.no_fnmatch_line("*def test_hello*") -def test_pytest_no_tests_collected_exit_status(testdir): +def test_pytest_no_tests_collected_exit_status(testdir) -> None: result = testdir.runpytest() result.stdout.fnmatch_lines(["*collected 0 items*"]) assert result.ret == main.ExitCode.NO_TESTS_COLLECTED @@ -685,16 +700,17 @@ def test_foo(): assert result.ret == main.ExitCode.NO_TESTS_COLLECTED -def test_exception_printing_skip(): +def test_exception_printing_skip() -> None: + assert Skipped == pytest.skip.Exception # type: ignore try: pytest.skip("hello") - except pytest.skip.Exception: + except Skipped: excinfo = _pytest._code.ExceptionInfo.from_current() s = excinfo.exconly(tryshort=True) assert s.startswith("Skipped") -def test_importorskip(monkeypatch): +def test_importorskip(monkeypatch) -> None: importorskip = pytest.importorskip def f(): @@ -705,45 +721,49 @@ def f(): assert sysmod is sys # path = pytest.importorskip("os.path") # assert path == os.path - excinfo = pytest.raises(pytest.skip.Exception, f) - path = py.path.local(excinfo.getrepr().reprcrash.path) + excinfo = pytest.raises(Skipped, f) + assert excinfo is not None + excrepr = excinfo.getrepr() + assert excrepr is not None + assert excrepr.reprcrash is not None + path = py.path.local(excrepr.reprcrash.path) # check that importorskip reports the actual call # in this test the test_runner.py file assert path.purebasename == "test_runner" pytest.raises(SyntaxError, pytest.importorskip, "x y z") pytest.raises(SyntaxError, pytest.importorskip, "x=y") mod = types.ModuleType("hello123") - mod.__version__ = "1.3" + mod.__version__ = "1.3" # type: ignore monkeypatch.setitem(sys.modules, "hello123", mod) - with pytest.raises(pytest.skip.Exception): + with pytest.raises(Skipped): pytest.importorskip("hello123", minversion="1.3.1") mod2 = pytest.importorskip("hello123", minversion="1.3") assert mod2 == mod - except pytest.skip.Exception: + except Skipped: print(_pytest._code.ExceptionInfo.from_current()) pytest.fail("spurious skip") -def test_importorskip_imports_last_module_part(): +def test_importorskip_imports_last_module_part() -> None: ospath = pytest.importorskip("os.path") assert os.path == ospath -def test_importorskip_dev_module(monkeypatch): +def test_importorskip_dev_module(monkeypatch) -> None: try: mod = types.ModuleType("mockmodule") - mod.__version__ = "0.13.0.dev-43290" + mod.__version__ = "0.13.0.dev-43290" # type: ignore monkeypatch.setitem(sys.modules, "mockmodule", mod) mod2 = pytest.importorskip("mockmodule", minversion="0.12.0") assert mod2 == mod - with pytest.raises(pytest.skip.Exception): + with pytest.raises(Skipped): pytest.importorskip("mockmodule1", minversion="0.14.0") - except pytest.skip.Exception: + except Skipped: print(_pytest._code.ExceptionInfo.from_current()) pytest.fail("spurious skip") -def test_importorskip_module_level(testdir): +def test_importorskip_module_level(testdir) -> None: """importorskip must be able to skip entire modules when used at module level""" testdir.makepyfile( """ @@ -758,7 +778,7 @@ def test_foo(): result.stdout.fnmatch_lines(["*collected 0 items / 1 skipped*"]) -def test_importorskip_custom_reason(testdir): +def test_importorskip_custom_reason(testdir) -> None: """make sure custom reasons are used""" testdir.makepyfile( """ @@ -774,7 +794,7 @@ def test_foo(): result.stdout.fnmatch_lines(["*collected 0 items / 1 skipped*"]) -def test_pytest_cmdline_main(testdir): +def test_pytest_cmdline_main(testdir) -> None: p = testdir.makepyfile( """ import pytest @@ -792,7 +812,7 @@ def test_hello(): assert ret == 0 -def test_unicode_in_longrepr(testdir): +def test_unicode_in_longrepr(testdir) -> None: testdir.makeconftest( """\ import pytest @@ -815,7 +835,7 @@ def test_out(): assert "UnicodeEncodeError" not in result.stderr.str() -def test_failure_in_setup(testdir): +def test_failure_in_setup(testdir) -> None: testdir.makepyfile( """ def setup_module(): @@ -828,7 +848,7 @@ def test_func(): result.stdout.no_fnmatch_line("*def setup_module*") -def test_makereport_getsource(testdir): +def test_makereport_getsource(testdir) -> None: testdir.makepyfile( """ def test_foo(): @@ -841,17 +861,17 @@ def test_foo(): result.stdout.fnmatch_lines(["*else: assert False*"]) -def test_makereport_getsource_dynamic_code(testdir, monkeypatch): +def test_makereport_getsource_dynamic_code(testdir, monkeypatch) -> None: """Test that exception in dynamically generated code doesn't break getting the source line.""" import inspect original_findsource = inspect.findsource - def findsource(obj, *args, **kwargs): + def findsource(obj): # Can be triggered by dynamically created functions if obj.__name__ == "foo": raise IndexError() - return original_findsource(obj, *args, **kwargs) + return original_findsource(obj) monkeypatch.setattr(inspect, "findsource", findsource) @@ -872,7 +892,7 @@ def test_fix(foo): result.stdout.fnmatch_lines(["*test_fix*", "*fixture*'missing'*not found*"]) -def test_store_except_info_on_error(): +def test_store_except_info_on_error() -> None: """ Test that upon test failure, the exception info is stored on sys.last_traceback and friends. """ @@ -891,6 +911,7 @@ def runtest(self): pass # Check that exception info is stored on sys assert sys.last_type is IndexError + assert isinstance(sys.last_value, IndexError) assert sys.last_value.args[0] == "TEST" assert sys.last_traceback @@ -902,8 +923,8 @@ def runtest(self): assert not hasattr(sys, "last_traceback") -def test_current_test_env_var(testdir, monkeypatch): - pytest_current_test_vars = [] +def test_current_test_env_var(testdir, monkeypatch) -> None: + pytest_current_test_vars = [] # type: List[Tuple[str, str]] monkeypatch.setattr( sys, "pytest_current_test_vars", pytest_current_test_vars, raising=False ) @@ -942,7 +963,7 @@ class TestReportContents: def getrunner(self): return lambda item: runner.runtestprotocol(item, log=False) - def test_longreprtext_pass(self, testdir): + def test_longreprtext_pass(self, testdir) -> None: reports = testdir.runitem( """ def test_func(): @@ -952,7 +973,7 @@ def test_func(): rep = reports[1] assert rep.longreprtext == "" - def test_longreprtext_failure(self, testdir): + def test_longreprtext_failure(self, testdir) -> None: reports = testdir.runitem( """ def test_func(): @@ -963,7 +984,7 @@ def test_func(): rep = reports[1] assert "assert 1 == 4" in rep.longreprtext - def test_captured_text(self, testdir): + def test_captured_text(self, testdir) -> None: reports = testdir.runitem( """ import pytest @@ -993,7 +1014,7 @@ def test_func(fix): assert call.capstderr == "setup: stderr\ncall: stderr\n" assert teardown.capstderr == "setup: stderr\ncall: stderr\nteardown: stderr\n" - def test_no_captured_text(self, testdir): + def test_no_captured_text(self, testdir) -> None: reports = testdir.runitem( """ def test_func(): @@ -1005,10 +1026,10 @@ def test_func(): assert rep.capstderr == "" -def test_outcome_exception_bad_msg(): +def test_outcome_exception_bad_msg() -> None: """Check that OutcomeExceptions validate their input to prevent confusing errors (#5578)""" - def func(): + def func() -> None: pass expected = ( @@ -1016,5 +1037,5 @@ def func(): "Perhaps you meant to use a mark?" ) with pytest.raises(TypeError) as excinfo: - OutcomeException(func) + OutcomeException(func) # type: ignore assert str(excinfo.value) == expected