Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for exception groups and except* #14020

Merged
merged 1 commit into from Nov 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 31 additions & 6 deletions mypy/checker.py
Expand Up @@ -4305,7 +4305,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
with self.binder.frame_context(can_skip=True, fall_through=4):
typ = s.types[i]
if typ:
t = self.check_except_handler_test(typ)
t = self.check_except_handler_test(typ, s.is_star)
var = s.vars[i]
if var:
# To support local variables, we make this a definition line,
Expand All @@ -4325,7 +4325,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
if s.else_body:
self.accept(s.else_body)

def check_except_handler_test(self, n: Expression) -> Type:
def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
"""Type check an exception handler test clause."""
typ = self.expr_checker.accept(n)

Expand All @@ -4341,22 +4341,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
item = ttype.items[0]
if not item.is_type_obj():
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return AnyType(TypeOfAny.from_error)
exc_type = item.ret_type
return self.default_exception_type(is_star)
exc_type = erase_typevars(item.ret_type)
elif isinstance(ttype, TypeType):
exc_type = ttype.item
else:
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return AnyType(TypeOfAny.from_error)
return self.default_exception_type(is_star)

if not is_subtype(exc_type, self.named_type("builtins.BaseException")):
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return AnyType(TypeOfAny.from_error)
return self.default_exception_type(is_star)

all_types.append(exc_type)

if is_star:
new_all_types: list[Type] = []
for typ in all_types:
if is_proper_subtype(typ, self.named_type("builtins.BaseExceptionGroup")):
self.fail(message_registry.INVALID_EXCEPTION_GROUP, n)
new_all_types.append(AnyType(TypeOfAny.from_error))
else:
new_all_types.append(typ)
return self.wrap_exception_group(new_all_types)
return make_simplified_union(all_types)

def default_exception_type(self, is_star: bool) -> Type:
"""Exception type to return in case of a previous type error."""
any_type = AnyType(TypeOfAny.from_error)
if is_star:
return self.named_generic_type("builtins.ExceptionGroup", [any_type])
return any_type

def wrap_exception_group(self, types: Sequence[Type]) -> Type:
"""Transform except* variable type into an appropriate exception group."""
arg = make_simplified_union(types)
if is_subtype(arg, self.named_type("builtins.Exception")):
base = "builtins.ExceptionGroup"
else:
base = "builtins.BaseExceptionGroup"
return self.named_generic_type(base, [arg])

def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
"""Helper for check_except_handler_test to retrieve handler types."""
typ = get_proper_type(typ)
Expand Down
2 changes: 1 addition & 1 deletion mypy/fastparse.py
Expand Up @@ -1254,7 +1254,6 @@ def visit_Try(self, n: ast3.Try) -> TryStmt:
return self.set_line(node, n)

def visit_TryStar(self, n: TryStar) -> TryStmt:
# TODO: we treat TryStar exactly like Try, which makes mypy not crash. See #12840
vs = [
self.set_line(NameExpr(h.name), h) if h.name is not None else None for h in n.handlers
]
Expand All @@ -1269,6 +1268,7 @@ def visit_TryStar(self, n: TryStar) -> TryStmt:
self.as_block(n.orelse, n.lineno),
self.as_block(n.finalbody, n.lineno),
)
node.is_star = True
return self.set_line(node, n)

# Assert(expr test, expr? msg)
Expand Down
3 changes: 3 additions & 0 deletions mypy/message_registry.py
Expand Up @@ -44,6 +44,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
NO_RETURN_EXPECTED: Final = ErrorMessage("Return statement in function which does not return")
INVALID_EXCEPTION: Final = ErrorMessage("Exception must be derived from BaseException")
INVALID_EXCEPTION_TYPE: Final = ErrorMessage("Exception type must be derived from BaseException")
INVALID_EXCEPTION_GROUP: Final = ErrorMessage(
"Exception type in except* cannot derive from BaseExceptionGroup"
)
RETURN_IN_ASYNC_GENERATOR: Final = ErrorMessage(
'"return" with value in async generator is not allowed'
)
Expand Down
7 changes: 5 additions & 2 deletions mypy/nodes.py
Expand Up @@ -1529,9 +1529,9 @@ def accept(self, visitor: StatementVisitor[T]) -> T:


class TryStmt(Statement):
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body")
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star")

__match_args__ = ("body", "types", "vars", "handlers", "else_body", "finally_body")
__match_args__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star")

body: Block # Try body
# Plain 'except:' also possible
Expand All @@ -1540,6 +1540,8 @@ class TryStmt(Statement):
handlers: list[Block] # Except bodies
else_body: Block | None
finally_body: Block | None
# Whether this is try ... except* (added in Python 3.11)
is_star: bool

def __init__(
self,
Expand All @@ -1557,6 +1559,7 @@ def __init__(
self.handlers = handlers
self.else_body = else_body
self.finally_body = finally_body
self.is_star = False

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_try_stmt(self)
Expand Down
2 changes: 2 additions & 0 deletions mypy/strconv.py
Expand Up @@ -276,6 +276,8 @@ def visit_del_stmt(self, o: mypy.nodes.DelStmt) -> str:

def visit_try_stmt(self, o: mypy.nodes.TryStmt) -> str:
a: list[Any] = [o.body]
if o.is_star:
a.append("*")

for i in range(len(o.vars)):
a.append(o.types[i])
Expand Down
4 changes: 3 additions & 1 deletion mypy/treetransform.py
Expand Up @@ -373,14 +373,16 @@ def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))

def visit_try_stmt(self, node: TryStmt) -> TryStmt:
return TryStmt(
new = TryStmt(
self.block(node.body),
self.optional_names(node.vars),
self.optional_expressions(node.types),
self.blocks(node.handlers),
self.optional_block(node.else_body),
self.optional_block(node.finally_body),
)
new.is_star = node.is_star
return new

def visit_with_stmt(self, node: WithStmt) -> WithStmt:
new = WithStmt(
Expand Down
2 changes: 2 additions & 0 deletions mypyc/irbuild/statement.py
Expand Up @@ -616,6 +616,8 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
# constructs that we compile separately. When we have a
# try/except/else/finally, we treat the try/except/else as the
# body of a try/finally block.
if t.is_star:
builder.error("Exception groups and except* cannot be compiled yet", t.line)
if t.finally_body:

def transform_try_body() -> None:
Expand Down
51 changes: 49 additions & 2 deletions test-data/unit/check-python311.test
@@ -1,6 +1,53 @@
[case testTryStarDoesNotCrash]
[case testTryStarSimple]
try:
pass
except* Exception as e:
reveal_type(e) # N: Revealed type is "builtins.Exception"
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
[builtins fixtures/exception.pyi]

[case testTryStarMultiple]
try:
pass
except* Exception as e:
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
except* RuntimeError as e:
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.RuntimeError]"
[builtins fixtures/exception.pyi]

[case testTryStarBase]
try:
pass
except* BaseException as e:
reveal_type(e) # N: Revealed type is "builtins.BaseExceptionGroup[builtins.BaseException]"
[builtins fixtures/exception.pyi]

[case testTryStarTuple]
class Custom(Exception): ...

try:
pass
except* (RuntimeError, Custom) as e:
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, __main__.Custom]]"
[builtins fixtures/exception.pyi]

[case testTryStarInvalidType]
class Bad: ...
try:
pass
except* (RuntimeError, Bad) as e: # E: Exception type must be derived from BaseException
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
[builtins fixtures/exception.pyi]

[case testTryStarGroupInvalid]
try:
pass
except* ExceptionGroup as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
[builtins fixtures/exception.pyi]

[case testTryStarGroupInvalidTuple]
try:
pass
except* (RuntimeError, ExceptionGroup) as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, Any]]"
[builtins fixtures/exception.pyi]
11 changes: 8 additions & 3 deletions test-data/unit/fixtures/exception.pyi
@@ -1,23 +1,28 @@
import sys
from typing import Generic, TypeVar
T = TypeVar('T')

class object:
def __init__(self): pass

class type: pass
class tuple(Generic[T]): pass
class tuple(Generic[T]):
def __ge__(self, other: object) -> bool: ...
class function: pass
class int: pass
class str: pass
class unicode: pass
class bool: pass
class ellipsis: pass

# Note: this is a slight simplification. In Python 2, the inheritance hierarchy
# is actually Exception -> StandardError -> RuntimeError -> ...
class BaseException:
def __init__(self, *args: object) -> None: ...
class Exception(BaseException): pass
class RuntimeError(Exception): pass
class NotImplementedError(RuntimeError): pass

if sys.version_info >= (3, 11):
_BT_co = TypeVar("_BT_co", bound=BaseException, covariant=True)
_T_co = TypeVar("_T_co", bound=Exception, covariant=True)
class BaseExceptionGroup(BaseException, Generic[_BT_co]): ...
class ExceptionGroup(BaseExceptionGroup[_T_co], Exception): ...