From 1a96f16401c7bf10d4873f54b02f7bbbf94f492d Mon Sep 17 00:00:00 2001 From: Ilya Konstantinov Date: Fri, 13 Jan 2023 09:29:38 -0500 Subject: [PATCH] Derive pytest.raises from AbstractContextManager Makes `AbstractContextManager` the shared base class between "raises" and other context managers. The motivation is for type checkers to narrow `pytest.raises(...) if x else nullcontext()` to a `ContextManager` rather than `object`. --- changelog/10660.bugfix.rst | 2 ++ src/_pytest/python_api.py | 4 ++-- testing/typing_checks.py | 11 +++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 changelog/10660.bugfix.rst diff --git a/changelog/10660.bugfix.rst b/changelog/10660.bugfix.rst new file mode 100644 index 00000000000..62e3549413b --- /dev/null +++ b/changelog/10660.bugfix.rst @@ -0,0 +1,2 @@ +Fix :py:func:`pytest.raises` to return a 'ContextManager' so that type-checkers could narrow +:code:`pytest.raises(...) if ... else nullcontext()` down to 'ContextManager' rather than 'object'. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index ea45753cde6..4bc9348613a 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -8,7 +8,7 @@ from typing import Any from typing import Callable from typing import cast -from typing import Generic +from typing import ContextManager from typing import List from typing import Mapping from typing import Optional @@ -957,7 +957,7 @@ def raises( # noqa: F811 @final -class RaisesContext(Generic[E]): +class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): def __init__( self, expected_exception: Union[Type[E], Tuple[Type[E], ...]], diff --git a/testing/typing_checks.py b/testing/typing_checks.py index 0a6b5ad2841..d15b3988bb5 100644 --- a/testing/typing_checks.py +++ b/testing/typing_checks.py @@ -3,6 +3,11 @@ This file is not executed, it is only checked by mypy to ensure that none of the code triggers any mypy errors. """ +import contextlib +from typing import Optional + +from typing_extensions import assert_type + import pytest @@ -22,3 +27,9 @@ def check_fixture_ids_callable() -> None: @pytest.mark.parametrize("func", [str, int], ids=lambda x: str(x.__name__)) def check_parametrize_ids_callable(func) -> None: pass + + +def check_raises_is_a_context_manager(val: bool) -> None: + with pytest.raises(RuntimeError) if val else contextlib.nullcontext() as excinfo: + pass + assert_type(excinfo, Optional[pytest.ExceptionInfo[RuntimeError]])