Skip to content

Commit

Permalink
Updated type annotations to match those in typeshed (agronholm#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Nov 13, 2022
1 parent 2167639 commit cbac14e
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ Version history

This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Updated type annotations to match the ones in ``typeshed``

**1.0.1**

- Fixed formatted traceback missing exceptions beyond 2 nesting levels of
Expand Down
149 changes: 122 additions & 27 deletions src/exceptiongroup/_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from inspect import getmro, isclass
from typing import Any, Callable, Generic, Tuple, Type, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar, cast, overload

T = TypeVar("T", bound="BaseExceptionGroup")
EBase = TypeVar("EBase", bound=BaseException)
E = TypeVar("E", bound=Exception)
_SplitCondition = Union[
Type[EBase],
Tuple[Type[EBase], ...],
Callable[[EBase], bool],
]
if TYPE_CHECKING:
from typing import Self

_BaseExceptionT_co = TypeVar("_BaseExceptionT_co", bound=BaseException, covariant=True)
_BaseExceptionT = TypeVar("_BaseExceptionT", bound=BaseException)
_ExceptionT_co = TypeVar("_ExceptionT_co", bound=Exception, covariant=True)
_ExceptionT = TypeVar("_ExceptionT", bound=Exception)


def check_direct_subclass(
Expand All @@ -25,7 +24,11 @@ def check_direct_subclass(
return False


def get_condition_filter(condition: _SplitCondition) -> Callable[[BaseException], bool]:
def get_condition_filter(
condition: type[_BaseExceptionT]
| tuple[type[_BaseExceptionT], ...]
| Callable[[_BaseExceptionT_co], bool]
) -> Callable[[_BaseExceptionT_co], bool]:
if isclass(condition) and issubclass(
cast(Type[BaseException], condition), BaseException
):
Expand All @@ -34,17 +37,17 @@ def get_condition_filter(condition: _SplitCondition) -> Callable[[BaseException]
if all(isclass(x) and issubclass(x, BaseException) for x in condition):
return partial(check_direct_subclass, parents=condition)
elif callable(condition):
return cast(Callable[[BaseException], bool], condition)
return cast("Callable[[BaseException], bool]", condition)

raise TypeError("expected a function, exception type or tuple of exception types")


class BaseExceptionGroup(BaseException, Generic[EBase]):
class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
"""A combination of multiple unrelated exceptions."""

def __new__(
cls, __message: str, __exceptions: Sequence[EBase]
) -> BaseExceptionGroup[EBase] | ExceptionGroup[E]:
cls, __message: str, __exceptions: Sequence[_BaseExceptionT_co]
) -> Self:
if not isinstance(__message, str):
raise TypeError(f"argument 1 must be str, not {type(__message)}")
if not isinstance(__exceptions, Sequence):
Expand All @@ -66,7 +69,9 @@ def __new__(

return super().__new__(cls, __message, __exceptions)

def __init__(self, __message: str, __exceptions: Sequence[EBase], *args: Any):
def __init__(
self, __message: str, __exceptions: Sequence[_BaseExceptionT_co], *args: Any
):
super().__init__(__message, __exceptions, *args)
self._message = __message
self._exceptions = __exceptions
Expand All @@ -87,10 +92,29 @@ def message(self) -> str:
return self._message

@property
def exceptions(self) -> tuple[EBase, ...]:
def exceptions(
self,
) -> tuple[_BaseExceptionT_co | BaseExceptionGroup[_BaseExceptionT_co], ...]:
return tuple(self._exceptions)

def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
@overload
def subgroup(
self, __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...]
) -> BaseExceptionGroup[_BaseExceptionT] | None:
...

@overload
def subgroup(
self: Self, __condition: Callable[[_BaseExceptionT_co], bool]
) -> Self | None:
...

def subgroup(
self: Self,
__condition: type[_BaseExceptionT]
| tuple[type[_BaseExceptionT], ...]
| Callable[[_BaseExceptionT_co], bool],
) -> BaseExceptionGroup[_BaseExceptionT] | Self | None:
condition = get_condition_filter(__condition)
modified = False
if condition(self):
Expand All @@ -99,7 +123,7 @@ def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
exceptions: list[BaseException] = []
for exc in self.exceptions:
if isinstance(exc, BaseExceptionGroup):
subgroup = exc.subgroup(condition)
subgroup = exc.subgroup(__condition)
if subgroup is not None:
exceptions.append(subgroup)

Expand All @@ -121,9 +145,27 @@ def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
else:
return None

@overload
def split(
self: Self,
__condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...],
) -> tuple[BaseExceptionGroup[_BaseExceptionT] | None, Self | None]:
...

@overload
def split(
self: T, __condition: _SplitCondition[EBase]
) -> tuple[T | None, T | None]:
self: Self, __condition: Callable[[_BaseExceptionT_co], bool]
) -> tuple[Self | None, Self | None]:
...

def split(
self: Self,
__condition: type[_BaseExceptionT]
| tuple[type[_BaseExceptionT], ...]
| Callable[[_BaseExceptionT_co], bool],
) -> tuple[BaseExceptionGroup[_BaseExceptionT] | None, Self | None] | tuple[
Self | None, Self | None
]:
condition = get_condition_filter(__condition)
if condition(self):
return self, None
Expand All @@ -143,14 +185,14 @@ def split(
else:
nonmatching_exceptions.append(exc)

matching_group: T | None = None
matching_group: Self | None = None
if matching_exceptions:
matching_group = self.derive(matching_exceptions)
matching_group.__cause__ = self.__cause__
matching_group.__context__ = self.__context__
matching_group.__traceback__ = self.__traceback__

nonmatching_group: T | None = None
nonmatching_group: Self | None = None
if nonmatching_exceptions:
nonmatching_group = self.derive(nonmatching_exceptions)
nonmatching_group.__cause__ = self.__cause__
Expand All @@ -159,11 +201,12 @@ def split(

return matching_group, nonmatching_group

def derive(self: T, __excs: Sequence[EBase]) -> T:
def derive(self: Self, __excs: Sequence[_BaseExceptionT_co]) -> Self:
eg = BaseExceptionGroup(self.message, __excs)
if hasattr(self, "__notes__"):
# Create a new list so that add_note() only affects one exceptiongroup
eg.__notes__ = list(self.__notes__)

return eg

def __str__(self) -> str:
Expand All @@ -174,12 +217,64 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.message!r}, {self._exceptions!r})"


class ExceptionGroup(BaseExceptionGroup[E], Exception, Generic[E]):
def __new__(cls, __message: str, __exceptions: Sequence[E]) -> ExceptionGroup[E]:
instance: ExceptionGroup[E] = super().__new__(cls, __message, __exceptions)
class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
def __new__(cls, __message: str, __exceptions: Sequence[_ExceptionT_co]) -> Self:
instance: ExceptionGroup[_ExceptionT_co] = super().__new__(
cls, __message, __exceptions
)
if cls is ExceptionGroup:
for exc in __exceptions:
if not isinstance(exc, Exception):
raise TypeError("Cannot nest BaseExceptions in an ExceptionGroup")

return instance

if TYPE_CHECKING:

@property
def exceptions(
self,
) -> tuple[_ExceptionT_co | ExceptionGroup[_ExceptionT_co], ...]:
...

@overload # type: ignore[override]
def subgroup(
self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
) -> ExceptionGroup[_ExceptionT] | None:
...

@overload
def subgroup(
self: Self, __condition: Callable[[_ExceptionT_co], bool]
) -> Self | None:
...

def subgroup(
self: Self,
__condition: type[_ExceptionT]
| tuple[type[_ExceptionT], ...]
| Callable[[_ExceptionT_co], bool],
) -> ExceptionGroup[_ExceptionT] | Self | None:
return super().subgroup(__condition)

@overload # type: ignore[override]
def split(
self: Self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
) -> tuple[ExceptionGroup[_ExceptionT] | None, Self | None]:
...

@overload
def split(
self: Self, __condition: Callable[[_ExceptionT_co], bool]
) -> tuple[Self | None, Self | None]:
...

def split(
self: Self,
__condition: type[_ExceptionT]
| tuple[type[_ExceptionT], ...]
| Callable[[_ExceptionT_co], bool],
) -> tuple[ExceptionGroup[_ExceptionT] | None, Self | None] | tuple[
Self | None, Self | None
]:
return super().split(__condition)

0 comments on commit cbac14e

Please sign in to comment.