Skip to content

Commit

Permalink
Fixed bare raise and exception chaining when a handler raises an ex…
Browse files Browse the repository at this point in the history
…ception (#71)
  • Loading branch information
agronholm committed Aug 9, 2023
1 parent 0c94abe commit 8b8791b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- `catch()` now raises a `TypeError` if passed an async exception handler instead of
just giving a `RuntimeWarning` about the coroutine never being awaited. (#66, PR by
John Litborn)
- Fixed plain ``raise`` statement in an exception handler callback to work like a
``raise`` in an ``except*`` block
- Fixed new exception group not being chained to the original exception when raising an
exception group from exceptions raised in handler callbacks

**1.1.2**

Expand Down
21 changes: 16 additions & 5 deletions src/exceptiongroup/_catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ def __exit__(
elif unhandled is None:
return True
else:
raise unhandled from None
if isinstance(exc, BaseExceptionGroup):
try:
raise unhandled from exc.__cause__
except BaseExceptionGroup:
# Change __context__ to __cause__ because Python 3.11 does this
# too
unhandled.__context__ = exc.__cause__
raise

raise unhandled from exc

return False

Expand All @@ -50,7 +59,12 @@ def handle_exception(self, exc: BaseException) -> BaseException | None:
matched, excgroup = excgroup.split(exc_types)
if matched:
try:
result = handler(matched)
try:
raise matched
except BaseExceptionGroup:
result = handler(matched)
except BaseExceptionGroup as new_exc:
new_exceptions.extend(new_exc.exceptions)
except BaseException as new_exc:
new_exceptions.append(new_exc)
else:
Expand All @@ -67,9 +81,6 @@ def handle_exception(self, exc: BaseException) -> BaseException | None:
if len(new_exceptions) == 1:
return new_exceptions[0]

if excgroup:
new_exceptions.append(excgroup)

return BaseExceptionGroup("", new_exceptions)
elif (
excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc
Expand Down
36 changes: 34 additions & 2 deletions tests/test_catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,41 @@ def test_catch_handler_raises():
def handler(exc):
raise RuntimeError("new")

with pytest.raises(RuntimeError, match="new"):
with pytest.raises(RuntimeError, match="new") as exc:
with catch({(ValueError, ValueError): handler}):
raise ExceptionGroup("booboo", [ValueError("bar")])
excgrp = ExceptionGroup("booboo", [ValueError("bar")])
raise excgrp

context = exc.value.__context__
assert isinstance(context, ExceptionGroup)
assert str(context) == "booboo (1 sub-exception)"
assert len(context.exceptions) == 1
assert isinstance(context.exceptions[0], ValueError)
assert exc.value.__cause__ is None


def test_bare_raise_in_handler():
"""Test that a bare "raise" "middle" ecxeption group gets discarded."""

def handler(exc):
raise

with pytest.raises(ExceptionGroup) as excgrp:
with catch({(ValueError,): handler, (RuntimeError,): lambda eg: None}):
try:
first_exc = RuntimeError("first")
raise first_exc
except RuntimeError as exc:
middle_exc = ExceptionGroup(
"bad", [ValueError(), ValueError(), TypeError()]
)
raise middle_exc from exc

assert len(excgrp.value.exceptions) == 2
assert all(isinstance(exc, ValueError) for exc in excgrp.value.exceptions)
assert excgrp.value is not middle_exc
assert excgrp.value.__cause__ is first_exc
assert excgrp.value.__context__ is first_exc


def test_catch_subclass():
Expand Down
34 changes: 32 additions & 2 deletions tests/test_catch_py311.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,20 @@ def test_catch_full_match():
reason="Behavior was changed in 3.11.4",
)
def test_catch_handler_raises():
with pytest.raises(RuntimeError, match="new"):
with pytest.raises(RuntimeError, match="new") as exc:
try:
raise ExceptionGroup("booboo", [ValueError("bar")])
excgrp = ExceptionGroup("booboo", [ValueError("bar")])
raise excgrp
except* ValueError:
raise RuntimeError("new")

context = exc.value.__context__
assert isinstance(context, ExceptionGroup)
assert str(context) == "booboo (1 sub-exception)"
assert len(context.exceptions) == 1
assert isinstance(context.exceptions[0], ValueError)
assert exc.value.__cause__ is None


def test_catch_subclass():
lookup_errors = []
Expand All @@ -146,3 +154,25 @@ def test_catch_subclass():
assert isinstance(lookup_errors[0], ExceptionGroup)
exceptions = lookup_errors[0].exceptions
assert isinstance(exceptions[0], KeyError)


def test_bare_raise_in_handler():
"""Test that the "middle" ecxeption group gets discarded."""
with pytest.raises(ExceptionGroup) as excgrp:
try:
try:
first_exc = RuntimeError("first")
raise first_exc
except RuntimeError as exc:
middle_exc = ExceptionGroup(
"bad", [ValueError(), ValueError(), TypeError()]
)
raise middle_exc from exc
except* ValueError:
raise
except* TypeError:
pass

assert excgrp.value is not middle_exc
assert excgrp.value.__cause__ is first_exc
assert excgrp.value.__context__ is first_exc

0 comments on commit 8b8791b

Please sign in to comment.