diff --git a/changelog/10660.bugfix.rst b/changelog/10660.bugfix.rst new file mode 100644 index 00000000000..62e3549413b --- /dev/null +++ b/changelog/10660.bugfix.rst @@ -0,0 +1,2 @@ +Fix :py:func:`pytest.raises` to return a 'ContextManager' so that type-checkers could narrow +:code:`pytest.raises(...) if ... else nullcontext()` down to 'ContextManager' rather than 'object'. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 515d437f0d8..20908cb67b8 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -8,7 +8,7 @@ from typing import Any from typing import Callable from typing import cast -from typing import Generic +from typing import ContextManager from typing import List from typing import Mapping from typing import Optional @@ -957,7 +957,7 @@ def raises( # noqa: F811 @final -class RaisesContext(Generic[E]): +class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): def __init__( self, expected_exception: Union[Type[E], Tuple[Type[E], ...]], diff --git a/testing/typing_checks.py b/testing/typing_checks.py index 0a6b5ad2841..d15b3988bb5 100644 --- a/testing/typing_checks.py +++ b/testing/typing_checks.py @@ -3,6 +3,11 @@ This file is not executed, it is only checked by mypy to ensure that none of the code triggers any mypy errors. """ +import contextlib +from typing import Optional + +from typing_extensions import assert_type + import pytest @@ -22,3 +27,9 @@ def check_fixture_ids_callable() -> None: @pytest.mark.parametrize("func", [str, int], ids=lambda x: str(x.__name__)) def check_parametrize_ids_callable(func) -> None: pass + + +def check_raises_is_a_context_manager(val: bool) -> None: + with pytest.raises(RuntimeError) if val else contextlib.nullcontext() as excinfo: + pass + assert_type(excinfo, Optional[pytest.ExceptionInfo[RuntimeError]])