Skip to content

Commit

Permalink
fix type errors, add type tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Apr 15, 2024
1 parent 8d2f627 commit 3d17e19
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 10 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
pull_request:

jobs:
pyright:
typing:
strategy:
fail-fast: false
matrix:
Expand All @@ -23,9 +23,13 @@ jobs:
path: ~/.cache/pip
key: pip-pyright
- name: Install dependencies
run: pip install -e . pyright
- name: Run pyright
run: pip install -e . pyright mypy
- name: Run pyright verifytypes
run: pyright --verifytypes exceptiongroup --verbose
- name: Run pyright type test
run: pyright tests/type_test.py
- name: Run mypy type test
run: mypy tests/type_test.py

test:
strategy:
Expand Down
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ exclude_also = [
"@overload",
]

[tool.pyright]
# for type tests, the code itself isn't type checked in CI
reportUnnecessaryTypeIgnoreComment = true

[tool.mypy]
# for type tests, the code itself isn't type checked in CI
warn_unused_ignores = true

[tool.tox]
legacy_tox_ini = """
[tox]
Expand All @@ -91,7 +99,12 @@ commands = python -m pytest {posargs}
usedevelop = true
[testenv:{py37-,py38-,py39-,py310-,py311-,py312-,}pyright]
deps = pyright
commands = pyright --verifytypes exceptiongroup
deps =
pyright
mypy
commands =
pyright --verifytypes exceptiongroup
pyright tests/type_test.py
mypy tests/type_test.py
usedevelop = true
"""
4 changes: 2 additions & 2 deletions src/exceptiongroup/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
"""A combination of multiple unrelated exceptions."""

def __new__(
cls: _BaseExceptionGroupSelf,
cls: type[_BaseExceptionGroupSelf],
__message: str,
__exceptions: Sequence[_BaseExceptionT_co],
) -> _BaseExceptionGroupSelf:
Expand Down Expand Up @@ -265,7 +265,7 @@ def __repr__(self) -> str:

class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
def __new__(
cls, __message: str, __exceptions: Sequence[_ExceptionT_co]
cls: type[_ExceptionGroupSelf], __message: str, __exceptions: Sequence[_ExceptionT_co]
) -> _ExceptionGroupSelf:
return super().__new__(cls, __message, __exceptions)

Expand Down
6 changes: 3 additions & 3 deletions src/exceptiongroup/_suppress.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from contextlib import AbstractContextManager
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, cast

if sys.version_info < (3, 11):
from ._exceptions import BaseExceptionGroup
Expand All @@ -16,7 +16,7 @@
class suppress(BaseClass):
"""Backport of :class:`contextlib.suppress` from Python 3.12.1."""

def __init__(self, *exceptions: BaseException):
def __init__(self, *exceptions: type[BaseException]):
self._exceptions = exceptions

def __enter__(self) -> None:
Expand Down Expand Up @@ -44,7 +44,7 @@ def __exit__(
return True

if issubclass(exctype, BaseExceptionGroup):
match, rest = excinst.split(self._exceptions)
match, rest = cast(BaseExceptionGroup, excinst).split(self._exceptions)
if rest is None:
return True

Expand Down
34 changes: 34 additions & 0 deletions tests/type_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing_extensions import assert_type

from exceptiongroup import BaseExceptionGroup, ExceptionGroup, catch, suppress

# issue 117
a = BaseExceptionGroup("", (KeyboardInterrupt(),))
assert_type(a, BaseExceptionGroup[KeyboardInterrupt])
b = BaseExceptionGroup("", (ValueError(),))
assert_type(b, BaseExceptionGroup[ValueError])
c = ExceptionGroup("", (ValueError(),))
assert_type(c, ExceptionGroup[ValueError])

# expected type error when passing a BaseException to ExceptionGroup
ExceptionGroup("", (KeyboardInterrupt(),)) # type: ignore


# code snippets from the README

def value_key_err_handler(excgroup: BaseExceptionGroup) -> None:
for exc in excgroup.exceptions:
print('Caught exception:', type(exc))

def runtime_err_handler(exc: BaseExceptionGroup) -> None:
print('Caught runtime error')

with catch({
(ValueError, KeyError): value_key_err_handler,
RuntimeError: runtime_err_handler
}):
...


with suppress(RuntimeError):
raise ExceptionGroup("", [RuntimeError("boo")])

0 comments on commit 3d17e19

Please sign in to comment.