Skip to content

Commit

Permalink
Add support for exception groups and except* (#14020)
Browse files Browse the repository at this point in the history
Ref #12840

It looks like from the point of view of type checking support is quite
easy. Mypyc support however requires some actual work, so I don't
include it in this PR.
  • Loading branch information
ilevkivskyi authored and svalentin committed Nov 7, 2022
1 parent 91b6fc3 commit 719cef9
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 14 deletions.
37 changes: 31 additions & 6 deletions mypy/checker.py
Expand Up @@ -4307,7 +4307,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
with self.binder.frame_context(can_skip=True, fall_through=4):
typ = s.types[i]
if typ:
t = self.check_except_handler_test(typ)
t = self.check_except_handler_test(typ, s.is_star)
var = s.vars[i]
if var:
# To support local variables, we make this a definition line,
Expand All @@ -4327,7 +4327,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
if s.else_body:
self.accept(s.else_body)

def check_except_handler_test(self, n: Expression) -> Type:
def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
"""Type check an exception handler test clause."""
typ = self.expr_checker.accept(n)

Expand All @@ -4343,22 +4343,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
item = ttype.items[0]
if not item.is_type_obj():
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return AnyType(TypeOfAny.from_error)
exc_type = item.ret_type
return self.default_exception_type(is_star)
exc_type = erase_typevars(item.ret_type)
elif isinstance(ttype, TypeType):
exc_type = ttype.item
else:
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return AnyType(TypeOfAny.from_error)
return self.default_exception_type(is_star)

if not is_subtype(exc_type, self.named_type("builtins.BaseException")):
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return AnyType(TypeOfAny.from_error)
return self.default_exception_type(is_star)

all_types.append(exc_type)

if is_star:
new_all_types: list[Type] = []
for typ in all_types:
if is_proper_subtype(typ, self.named_type("builtins.BaseExceptionGroup")):
self.fail(message_registry.INVALID_EXCEPTION_GROUP, n)
new_all_types.append(AnyType(TypeOfAny.from_error))
else:
new_all_types.append(typ)
return self.wrap_exception_group(new_all_types)
return make_simplified_union(all_types)

def default_exception_type(self, is_star: bool) -> Type:
"""Exception type to return in case of a previous type error."""
any_type = AnyType(TypeOfAny.from_error)
if is_star:
return self.named_generic_type("builtins.ExceptionGroup", [any_type])
return any_type

def wrap_exception_group(self, types: Sequence[Type]) -> Type:
"""Transform except* variable type into an appropriate exception group."""
arg = make_simplified_union(types)
if is_subtype(arg, self.named_type("builtins.Exception")):
base = "builtins.ExceptionGroup"
else:
base = "builtins.BaseExceptionGroup"
return self.named_generic_type(base, [arg])

def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
"""Helper for check_except_handler_test to retrieve handler types."""
typ = get_proper_type(typ)
Expand Down
2 changes: 1 addition & 1 deletion mypy/fastparse.py
Expand Up @@ -1254,7 +1254,6 @@ def visit_Try(self, n: ast3.Try) -> TryStmt:
return self.set_line(node, n)

def visit_TryStar(self, n: TryStar) -> TryStmt:
# TODO: we treat TryStar exactly like Try, which makes mypy not crash. See #12840
vs = [
self.set_line(NameExpr(h.name), h) if h.name is not None else None for h in n.handlers
]
Expand All @@ -1269,6 +1268,7 @@ def visit_TryStar(self, n: TryStar) -> TryStmt:
self.as_block(n.orelse, n.lineno),
self.as_block(n.finalbody, n.lineno),
)
node.is_star = True
return self.set_line(node, n)

# Assert(expr test, expr? msg)
Expand Down
3 changes: 3 additions & 0 deletions mypy/message_registry.py
Expand Up @@ -44,6 +44,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
NO_RETURN_EXPECTED: Final = ErrorMessage("Return statement in function which does not return")
INVALID_EXCEPTION: Final = ErrorMessage("Exception must be derived from BaseException")
INVALID_EXCEPTION_TYPE: Final = ErrorMessage("Exception type must be derived from BaseException")
INVALID_EXCEPTION_GROUP: Final = ErrorMessage(
"Exception type in except* cannot derive from BaseExceptionGroup"
)
RETURN_IN_ASYNC_GENERATOR: Final = ErrorMessage(
'"return" with value in async generator is not allowed'
)
Expand Down
5 changes: 4 additions & 1 deletion mypy/nodes.py
Expand Up @@ -1485,7 +1485,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T:


class TryStmt(Statement):
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body")
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star")

body: Block # Try body
# Plain 'except:' also possible
Expand All @@ -1494,6 +1494,8 @@ class TryStmt(Statement):
handlers: list[Block] # Except bodies
else_body: Block | None
finally_body: Block | None
# Whether this is try ... except* (added in Python 3.11)
is_star: bool

def __init__(
self,
Expand All @@ -1511,6 +1513,7 @@ def __init__(
self.handlers = handlers
self.else_body = else_body
self.finally_body = finally_body
self.is_star = False

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_try_stmt(self)
Expand Down
2 changes: 2 additions & 0 deletions mypy/strconv.py
Expand Up @@ -276,6 +276,8 @@ def visit_del_stmt(self, o: mypy.nodes.DelStmt) -> str:

def visit_try_stmt(self, o: mypy.nodes.TryStmt) -> str:
a: list[Any] = [o.body]
if o.is_star:
a.append("*")

for i in range(len(o.vars)):
a.append(o.types[i])
Expand Down
4 changes: 3 additions & 1 deletion mypy/treetransform.py
Expand Up @@ -373,14 +373,16 @@ def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))

def visit_try_stmt(self, node: TryStmt) -> TryStmt:
return TryStmt(
new = TryStmt(
self.block(node.body),
self.optional_names(node.vars),
self.optional_expressions(node.types),
self.blocks(node.handlers),
self.optional_block(node.else_body),
self.optional_block(node.finally_body),
)
new.is_star = node.is_star
return new

def visit_with_stmt(self, node: WithStmt) -> WithStmt:
new = WithStmt(
Expand Down
2 changes: 2 additions & 0 deletions mypyc/irbuild/statement.py
Expand Up @@ -616,6 +616,8 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
# constructs that we compile separately. When we have a
# try/except/else/finally, we treat the try/except/else as the
# body of a try/finally block.
if t.is_star:
builder.error("Exception groups and except* cannot be compiled yet", t.line)
if t.finally_body:

def transform_try_body() -> None:
Expand Down
51 changes: 49 additions & 2 deletions test-data/unit/check-python311.test
@@ -1,6 +1,53 @@
[case testTryStarDoesNotCrash]
[case testTryStarSimple]
try:
pass
except* Exception as e:
reveal_type(e) # N: Revealed type is "builtins.Exception"
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
[builtins fixtures/exception.pyi]

[case testTryStarMultiple]
try:
pass
except* Exception as e:
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
except* RuntimeError as e:
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.RuntimeError]"
[builtins fixtures/exception.pyi]

[case testTryStarBase]
try:
pass
except* BaseException as e:
reveal_type(e) # N: Revealed type is "builtins.BaseExceptionGroup[builtins.BaseException]"
[builtins fixtures/exception.pyi]

[case testTryStarTuple]
class Custom(Exception): ...

try:
pass
except* (RuntimeError, Custom) as e:
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, __main__.Custom]]"
[builtins fixtures/exception.pyi]

[case testTryStarInvalidType]
class Bad: ...
try:
pass
except* (RuntimeError, Bad) as e: # E: Exception type must be derived from BaseException
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
[builtins fixtures/exception.pyi]

[case testTryStarGroupInvalid]
try:
pass
except* ExceptionGroup as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
[builtins fixtures/exception.pyi]

[case testTryStarGroupInvalidTuple]
try:
pass
except* (RuntimeError, ExceptionGroup) as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, Any]]"
[builtins fixtures/exception.pyi]
11 changes: 8 additions & 3 deletions test-data/unit/fixtures/exception.pyi
@@ -1,23 +1,28 @@
import sys
from typing import Generic, TypeVar
T = TypeVar('T')

class object:
def __init__(self): pass

class type: pass
class tuple(Generic[T]): pass
class tuple(Generic[T]):
def __ge__(self, other: object) -> bool: ...
class function: pass
class int: pass
class str: pass
class unicode: pass
class bool: pass
class ellipsis: pass

# Note: this is a slight simplification. In Python 2, the inheritance hierarchy
# is actually Exception -> StandardError -> RuntimeError -> ...
class BaseException:
def __init__(self, *args: object) -> None: ...
class Exception(BaseException): pass
class RuntimeError(Exception): pass
class NotImplementedError(RuntimeError): pass

if sys.version_info >= (3, 11):
_BT_co = TypeVar("_BT_co", bound=BaseException, covariant=True)
_T_co = TypeVar("_T_co", bound=Exception, covariant=True)
class BaseExceptionGroup(BaseException, Generic[_BT_co]): ...
class ExceptionGroup(BaseExceptionGroup[_T_co], Exception): ...

0 comments on commit 719cef9

Please sign in to comment.