Skip to content

Commit

Permalink
Backported upstream fix for gh-99553 (#51)
Browse files Browse the repository at this point in the history
Fixes #50.
  • Loading branch information
agronholm committed Dec 22, 2022
1 parent f32faa2 commit 6e0f331
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Expand Up @@ -3,6 +3,11 @@ Version history

This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Backported upstream fix for gh-99553 (custom subclasses of ``BaseExceptionGroup`` that
also inherit from ``Exception`` should not be able to wrap base exceptions)

**1.0.4**

- Fixed regression introduced in v1.0.3 where the code computing the suggestions would
Expand Down
22 changes: 13 additions & 9 deletions src/exceptiongroup/_exceptions.py
Expand Up @@ -67,6 +67,18 @@ def __new__(
if all(isinstance(exc, Exception) for exc in __exceptions):
cls = ExceptionGroup

if issubclass(cls, Exception):
for exc in __exceptions:
if not isinstance(exc, Exception):
if cls is ExceptionGroup:
raise TypeError(
"Cannot nest BaseExceptions in an ExceptionGroup"
)
else:
raise TypeError(
f"Cannot nest BaseExceptions in {cls.__name__!r}"
)

return super().__new__(cls, __message, __exceptions)

def __init__(
Expand Down Expand Up @@ -219,15 +231,7 @@ def __repr__(self) -> str:

class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
def __new__(cls, __message: str, __exceptions: Sequence[_ExceptionT_co]) -> Self:
instance: ExceptionGroup[_ExceptionT_co] = super().__new__(
cls, __message, __exceptions
)
if cls is ExceptionGroup:
for exc in __exceptions:
if not isinstance(exc, Exception):
raise TypeError("Cannot nest BaseExceptions in an ExceptionGroup")

return instance
return super().__new__(cls, __message, __exceptions)

if TYPE_CHECKING:

Expand Down
30 changes: 24 additions & 6 deletions tests/test_exceptions.py
Expand Up @@ -3,6 +3,8 @@
import sys
import unittest

import pytest

from exceptiongroup import BaseExceptionGroup, ExceptionGroup


Expand Down Expand Up @@ -90,19 +92,35 @@ def test_BEG_wraps_BaseException__creates_BEG(self):
beg = BaseExceptionGroup("beg", [ValueError(1), KeyboardInterrupt(2)])
self.assertIs(type(beg), BaseExceptionGroup)

def test_EG_subclass_wraps_anything(self):
def test_EG_subclass_wraps_non_base_exceptions(self):
class MyEG(ExceptionGroup):
pass

self.assertIs(type(MyEG("eg", [ValueError(12), TypeError(42)])), MyEG)
self.assertIs(type(MyEG("eg", [ValueError(12), KeyboardInterrupt(42)])), MyEG)

def test_BEG_subclass_wraps_anything(self):
class MyBEG(BaseExceptionGroup):
@pytest.mark.skipif(
sys.version_info[:3] == (3, 11, 0),
reason="Behavior was made stricter in 3.11.1",
)
def test_EG_subclass_does_not_wrap_base_exceptions(self):
class MyEG(ExceptionGroup):
pass

msg = "Cannot nest BaseExceptions in 'MyEG'"
with self.assertRaisesRegex(TypeError, msg):
MyEG("eg", [ValueError(12), KeyboardInterrupt(42)])

@pytest.mark.skipif(
sys.version_info[:3] == (3, 11, 0),
reason="Behavior was made stricter in 3.11.1",
)
def test_BEG_and_E_subclass_does_not_wrap_base_exceptions(self):
class MyEG(BaseExceptionGroup, ValueError):
pass

self.assertIs(type(MyBEG("eg", [ValueError(12), TypeError(42)])), MyBEG)
self.assertIs(type(MyBEG("eg", [ValueError(12), KeyboardInterrupt(42)])), MyBEG)
msg = "Cannot nest BaseExceptions in 'MyEG'"
with self.assertRaisesRegex(TypeError, msg):
MyEG("eg", [ValueError(12), KeyboardInterrupt(42)])


def create_simple_eg():
Expand Down

0 comments on commit 6e0f331

Please sign in to comment.