Skip to content

Commit

Permalink
BaseExceptionGroup.derive should not copy __notes__ (#112)
Browse files Browse the repository at this point in the history
This makes the behaviour follow that of CPython more closely. Instead, copy `__notes__` (if present) in the *callers* of derive. The (modified) test passes now, and it passes on Python 3.11. It fails before the changes.
  • Loading branch information
cfbolz committed Mar 5, 2024
1 parent 2f23259 commit ee53e9f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/exceptiongroup/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def get_condition_filter(
raise TypeError("expected a function, exception type or tuple of exception types")


def _derive_and_copy_attributes(self, excs):
eg = self.derive(excs)
eg.__cause__ = self.__cause__
eg.__context__ = self.__context__
eg.__traceback__ = self.__traceback__
if hasattr(self, "__notes__"):
# Create a new list so that add_note() only affects one exceptiongroup
eg.__notes__ = list(self.__notes__)
return eg


class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
"""A combination of multiple unrelated exceptions."""

Expand Down Expand Up @@ -154,10 +165,7 @@ def subgroup(
if not modified:
return self
elif exceptions:
group = self.derive(exceptions)
group.__cause__ = self.__cause__
group.__context__ = self.__context__
group.__traceback__ = self.__traceback__
group = _derive_and_copy_attributes(self, exceptions)
return group
else:
return None
Expand Down Expand Up @@ -230,17 +238,13 @@ def split(

matching_group: _BaseExceptionGroupSelf | None = None
if matching_exceptions:
matching_group = self.derive(matching_exceptions)
matching_group.__cause__ = self.__cause__
matching_group.__context__ = self.__context__
matching_group.__traceback__ = self.__traceback__
matching_group = _derive_and_copy_attributes(self, matching_exceptions)

nonmatching_group: _BaseExceptionGroupSelf | None = None
if nonmatching_exceptions:
nonmatching_group = self.derive(nonmatching_exceptions)
nonmatching_group.__cause__ = self.__cause__
nonmatching_group.__context__ = self.__context__
nonmatching_group.__traceback__ = self.__traceback__
nonmatching_group = _derive_and_copy_attributes(
self, nonmatching_exceptions
)

return matching_group, nonmatching_group

Expand All @@ -257,12 +261,7 @@ def derive(
def derive(
self, __excs: Sequence[_BaseExceptionT]
) -> BaseExceptionGroup[_BaseExceptionT]:
eg = BaseExceptionGroup(self.message, __excs)
if hasattr(self, "__notes__"):
# Create a new list so that add_note() only affects one exceptiongroup
eg.__notes__ = list(self.__notes__)

return eg
return BaseExceptionGroup(self.message, __excs)

def __str__(self) -> str:
suffix = "" if len(self._exceptions) == 1 else "s"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ def test_notes_is_list_of_strings_if_it_exists(self):
eg.add_note(note)
self.assertEqual(eg.__notes__, [note])

def test_derive_doesn_copy_notes(self):
eg = create_simple_eg()
eg.add_note("hello")
assert eg.__notes__ == ["hello"]
eg2 = eg.derive([ValueError()])
assert not hasattr(eg2, "__notes__")


class ExceptionGroupTestBase(unittest.TestCase):
def assertMatchesTemplate(self, exc, exc_type, template):
Expand Down Expand Up @@ -786,6 +793,7 @@ def derive(self, excs):
except ValueError as ve:
raise EG("eg", [ve, nested], 42)
except EG as e:
e.add_note("hello")
eg = e

self.assertMatchesTemplate(eg, EG, [ValueError(1), [TypeError(2)]])
Expand All @@ -796,29 +804,35 @@ def derive(self, excs):
self.assertMatchesTemplate(rest, EG, [ValueError(1), [TypeError(2)]])
self.assertEqual(rest.code, 42)
self.assertEqual(rest.exceptions[1].code, 101)
self.assertEqual(rest.__notes__, ["hello"])

# Match Everything
match, rest = self.split_exception_group(eg, (ValueError, TypeError))
self.assertMatchesTemplate(match, EG, [ValueError(1), [TypeError(2)]])
self.assertEqual(match.code, 42)
self.assertEqual(match.exceptions[1].code, 101)
self.assertEqual(match.__notes__, ["hello"])
self.assertIsNone(rest)

# Match ValueErrors
match, rest = self.split_exception_group(eg, ValueError)
self.assertMatchesTemplate(match, EG, [ValueError(1)])
self.assertEqual(match.code, 42)
self.assertEqual(match.__notes__, ["hello"])
self.assertMatchesTemplate(rest, EG, [[TypeError(2)]])
self.assertEqual(rest.code, 42)
self.assertEqual(rest.exceptions[0].code, 101)
self.assertEqual(rest.__notes__, ["hello"])

# Match TypeErrors
match, rest = self.split_exception_group(eg, TypeError)
self.assertMatchesTemplate(match, EG, [[TypeError(2)]])
self.assertEqual(match.code, 42)
self.assertEqual(match.exceptions[0].code, 101)
self.assertEqual(match.__notes__, ["hello"])
self.assertMatchesTemplate(rest, EG, [ValueError(1)])
self.assertEqual(rest.code, 42)
self.assertEqual(rest.__notes__, ["hello"])


def test_repr():
Expand Down

0 comments on commit ee53e9f

Please sign in to comment.