Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix type errors, add type tests #118

Merged
merged 12 commits into from
Apr 18, 2024
10 changes: 7 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
pull_request:

jobs:
pyright:
typing:
strategy:
fail-fast: false
matrix:
Expand All @@ -23,9 +23,13 @@ jobs:
path: ~/.cache/pip
key: pip-pyright
- name: Install dependencies
run: pip install -e . pyright
- name: Run pyright
run: pip install -e . pyright mypy
- name: Run pyright verifytypes
run: pyright --verifytypes exceptiongroup --verbose
- name: Run pyright type test
run: pyright tests/check_types.py
- name: Run mypy type test
run: mypy tests/check_types.py

test:
strategy:
Expand Down
25 changes: 19 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ version_scheme = "post-release"
local_scheme = "dirty-tag"
write_to = "src/exceptiongroup/_version.py"

[tool.ruff]
[tool.ruff.lint]
select = [
"E", "F", "W", # default flake-8
"I", # isort
Expand All @@ -55,11 +55,11 @@ select = [
"UP", # pyupgrade
]

[tool.ruff.pyupgrade]
[tool.ruff.lint.pyupgrade]
# Preserve types, even if a file imports `from __future__ import annotations`.
keep-runtime-typing = true

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["exceptiongroup"]

[tool.pytest.ini_options]
Expand All @@ -76,6 +76,14 @@ exclude_also = [
"@overload",
]

[tool.pyright]
# for type tests, the code itself isn't type checked in CI
reportUnnecessaryTypeIgnoreComment = true

[tool.mypy]
# for type tests, the code itself isn't type checked in CI
warn_unused_ignores = true

[tool.tox]
legacy_tox_ini = """
[tox]
Expand All @@ -90,8 +98,13 @@ extras = test
commands = python -m pytest {posargs}
usedevelop = true

[testenv:{py37-,py38-,py39-,py310-,py311-,py312-,}pyright]
deps = pyright
commands = pyright --verifytypes exceptiongroup
[testenv:{py37-,py38-,py39-,py310-,py311-,py312-,}typing]
jakkdl marked this conversation as resolved.
Show resolved Hide resolved
deps =
pyright
mypy
commands =
pyright --verifytypes exceptiongroup
pyright tests/check_types.py
mypy tests/check_types.py
usedevelop = true
"""
6 changes: 4 additions & 2 deletions src/exceptiongroup/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
"""A combination of multiple unrelated exceptions."""

def __new__(
cls: _BaseExceptionGroupSelf,
cls: type[_BaseExceptionGroupSelf],
__message: str,
__exceptions: Sequence[_BaseExceptionT_co],
) -> _BaseExceptionGroupSelf:
Expand Down Expand Up @@ -265,7 +265,9 @@ def __repr__(self) -> str:

class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
def __new__(
cls, __message: str, __exceptions: Sequence[_ExceptionT_co]
cls: type[_ExceptionGroupSelf],
__message: str,
__exceptions: Sequence[_ExceptionT_co],
) -> _ExceptionGroupSelf:
return super().__new__(cls, __message, __exceptions)

Expand Down
8 changes: 5 additions & 3 deletions src/exceptiongroup/_suppress.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import sys
from contextlib import AbstractContextManager
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, cast

if sys.version_info < (3, 11):
from ._exceptions import BaseExceptionGroup
Expand All @@ -16,7 +18,7 @@
class suppress(BaseClass):
"""Backport of :class:`contextlib.suppress` from Python 3.12.1."""

def __init__(self, *exceptions: BaseException):
def __init__(self, *exceptions: type[BaseException]):
self._exceptions = exceptions

def __enter__(self) -> None:
Expand Down Expand Up @@ -44,7 +46,7 @@ def __exit__(
return True

if issubclass(exctype, BaseExceptionGroup):
match, rest = excinst.split(self._exceptions)
match, rest = cast(BaseExceptionGroup, excinst).split(self._exceptions)
if rest is None:
return True

Expand Down
36 changes: 36 additions & 0 deletions tests/check_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing_extensions import assert_type

from exceptiongroup import BaseExceptionGroup, ExceptionGroup, catch, suppress

# issue 117
a = BaseExceptionGroup("", (KeyboardInterrupt(),))
assert_type(a, BaseExceptionGroup[KeyboardInterrupt])
b = BaseExceptionGroup("", (ValueError(),))
assert_type(b, BaseExceptionGroup[ValueError])
c = ExceptionGroup("", (ValueError(),))
assert_type(c, ExceptionGroup[ValueError])

# expected type error when passing a BaseException to ExceptionGroup
ExceptionGroup("", (KeyboardInterrupt(),)) # type: ignore[type-var]


# code snippets from the README


def value_key_err_handler(excgroup: BaseExceptionGroup) -> None:
for exc in excgroup.exceptions:
print("Caught exception:", type(exc))


def runtime_err_handler(exc: BaseExceptionGroup) -> None:
print("Caught runtime error")


with catch(
{(ValueError, KeyError): value_key_err_handler, RuntimeError: runtime_err_handler}
):
...


with suppress(RuntimeError):
raise ExceptionGroup("", [RuntimeError("boo")])