From 521f02f7b1282ea3aaba20f450beb6d6283a5037 Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Thu, 18 Apr 2024 23:16:54 +0200 Subject: [PATCH] Fixed type errors, added type tests (#118) --- .github/workflows/test.yml | 10 ++++++--- CHANGES.rst | 3 +++ pyproject.toml | 27 +++++++++++++++++------ src/exceptiongroup/_exceptions.py | 6 ++++-- src/exceptiongroup/_suppress.py | 8 ++++--- tests/check_types.py | 36 +++++++++++++++++++++++++++++++ 6 files changed, 75 insertions(+), 15 deletions(-) create mode 100644 tests/check_types.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c1a0343..aec0aa3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ on: pull_request: jobs: - pyright: + typing: strategy: fail-fast: false matrix: @@ -23,9 +23,13 @@ jobs: path: ~/.cache/pip key: pip-pyright - name: Install dependencies - run: pip install -e . pyright - - name: Run pyright + run: pip install -e . pyright mypy + - name: Run pyright --verifytypes run: pyright --verifytypes exceptiongroup --verbose + - name: Run pyright type test + run: pyright tests/check_types.py + - name: Run mypy type test + run: mypy tests/check_types.py test: strategy: diff --git a/CHANGES.rst b/CHANGES.rst index cb902f8..1c756b8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,8 +5,11 @@ This library adheres to `Semantic Versioning 2.0 `_. **UNRELEASED** +- Updated the copying of ``__notes__`` to match CPython behavior (PR by CF Bolz-Tereick) - Corrected the type annotation of the exception handler callback to accept a ``BaseExceptionGroup`` instead of ``BaseException`` +- Fixed type errors on Python < 3.10 and the type annotation of ``suppress()`` + (PR by John Litborn) **1.2.0** diff --git a/pyproject.toml b/pyproject.toml index f132de0..aa47cdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ version_scheme = "post-release" local_scheme = "dirty-tag" write_to = "src/exceptiongroup/_version.py" -[tool.ruff] +[tool.ruff.lint] select = [ "E", "F", "W", # default flake-8 "I", # isort @@ -55,11 +55,11 @@ select = [ "UP", # pyupgrade ] -[tool.ruff.pyupgrade] +[tool.ruff.lint.pyupgrade] # Preserve types, even if a file imports `from __future__ import annotations`. keep-runtime-typing = true -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["exceptiongroup"] [tool.pytest.ini_options] @@ -76,12 +76,20 @@ exclude_also = [ "@overload", ] +[tool.pyright] +# for type tests, the code itself isn't type checked in CI +reportUnnecessaryTypeIgnoreComment = true + +[tool.mypy] +# for type tests, the code itself isn't type checked in CI +warn_unused_ignores = true + [tool.tox] legacy_tox_ini = """ [tox] envlist = py37, py38, py39, py310, py311, py312, pypy3 labels = - pyright = py{310,311,312}-pyright + typing = py{310,311,312}-typing skip_missing_interpreters = true minversion = 4.0 @@ -91,8 +99,13 @@ commands = python -m pytest {posargs} package = editable usedevelop = true -[testenv:{py37-,py38-,py39-,py310-,py311-,py312-,}pyright] -deps = pyright -commands = pyright --verifytypes exceptiongroup +[testenv:{py37-,py38-,py39-,py310-,py311-,py312-,}typing] +deps = + pyright + mypy +commands = + pyright --verifytypes exceptiongroup + pyright tests/check_types.py + mypy tests/check_types.py usedevelop = true """ diff --git a/src/exceptiongroup/_exceptions.py b/src/exceptiongroup/_exceptions.py index 2513bd9..a4a7ace 100644 --- a/src/exceptiongroup/_exceptions.py +++ b/src/exceptiongroup/_exceptions.py @@ -57,7 +57,7 @@ class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]): """A combination of multiple unrelated exceptions.""" def __new__( - cls: _BaseExceptionGroupSelf, + cls: type[_BaseExceptionGroupSelf], __message: str, __exceptions: Sequence[_BaseExceptionT_co], ) -> _BaseExceptionGroupSelf: @@ -265,7 +265,9 @@ def __repr__(self) -> str: class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception): def __new__( - cls, __message: str, __exceptions: Sequence[_ExceptionT_co] + cls: type[_ExceptionGroupSelf], + __message: str, + __exceptions: Sequence[_ExceptionT_co], ) -> _ExceptionGroupSelf: return super().__new__(cls, __message, __exceptions) diff --git a/src/exceptiongroup/_suppress.py b/src/exceptiongroup/_suppress.py index baed57f..11467ee 100644 --- a/src/exceptiongroup/_suppress.py +++ b/src/exceptiongroup/_suppress.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import sys from contextlib import AbstractContextManager from types import TracebackType -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING, Optional, Type, cast if sys.version_info < (3, 11): from ._exceptions import BaseExceptionGroup @@ -16,7 +18,7 @@ class suppress(BaseClass): """Backport of :class:`contextlib.suppress` from Python 3.12.1.""" - def __init__(self, *exceptions: BaseException): + def __init__(self, *exceptions: type[BaseException]): self._exceptions = exceptions def __enter__(self) -> None: @@ -44,7 +46,7 @@ def __exit__( return True if issubclass(exctype, BaseExceptionGroup): - match, rest = excinst.split(self._exceptions) + match, rest = cast(BaseExceptionGroup, excinst).split(self._exceptions) if rest is None: return True diff --git a/tests/check_types.py b/tests/check_types.py new file mode 100644 index 0000000..f7a102d --- /dev/null +++ b/tests/check_types.py @@ -0,0 +1,36 @@ +from typing_extensions import assert_type + +from exceptiongroup import BaseExceptionGroup, ExceptionGroup, catch, suppress + +# issue 117 +a = BaseExceptionGroup("", (KeyboardInterrupt(),)) +assert_type(a, BaseExceptionGroup[KeyboardInterrupt]) +b = BaseExceptionGroup("", (ValueError(),)) +assert_type(b, BaseExceptionGroup[ValueError]) +c = ExceptionGroup("", (ValueError(),)) +assert_type(c, ExceptionGroup[ValueError]) + +# expected type error when passing a BaseException to ExceptionGroup +ExceptionGroup("", (KeyboardInterrupt(),)) # type: ignore[type-var] + + +# code snippets from the README + + +def value_key_err_handler(excgroup: BaseExceptionGroup) -> None: + for exc in excgroup.exceptions: + print("Caught exception:", type(exc)) + + +def runtime_err_handler(exc: BaseExceptionGroup) -> None: + print("Caught runtime error") + + +with catch( + {(ValueError, KeyError): value_key_err_handler, RuntimeError: runtime_err_handler} +): + ... + + +with suppress(RuntimeError): + raise ExceptionGroup("", [RuntimeError("boo")])