diff --git a/mypy/checker.py b/mypy/checker.py index 8973ade98228..5744a4ef4937 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4305,7 +4305,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None: with self.binder.frame_context(can_skip=True, fall_through=4): typ = s.types[i] if typ: - t = self.check_except_handler_test(typ) + t = self.check_except_handler_test(typ, s.is_star) var = s.vars[i] if var: # To support local variables, we make this a definition line, @@ -4325,7 +4325,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None: if s.else_body: self.accept(s.else_body) - def check_except_handler_test(self, n: Expression) -> Type: + def check_except_handler_test(self, n: Expression, is_star: bool) -> Type: """Type check an exception handler test clause.""" typ = self.expr_checker.accept(n) @@ -4341,22 +4341,47 @@ def check_except_handler_test(self, n: Expression) -> Type: item = ttype.items[0] if not item.is_type_obj(): self.fail(message_registry.INVALID_EXCEPTION_TYPE, n) - return AnyType(TypeOfAny.from_error) - exc_type = item.ret_type + return self.default_exception_type(is_star) + exc_type = erase_typevars(item.ret_type) elif isinstance(ttype, TypeType): exc_type = ttype.item else: self.fail(message_registry.INVALID_EXCEPTION_TYPE, n) - return AnyType(TypeOfAny.from_error) + return self.default_exception_type(is_star) if not is_subtype(exc_type, self.named_type("builtins.BaseException")): self.fail(message_registry.INVALID_EXCEPTION_TYPE, n) - return AnyType(TypeOfAny.from_error) + return self.default_exception_type(is_star) all_types.append(exc_type) + if is_star: + new_all_types: list[Type] = [] + for typ in all_types: + if is_proper_subtype(typ, self.named_type("builtins.BaseExceptionGroup")): + self.fail(message_registry.INVALID_EXCEPTION_GROUP, n) + new_all_types.append(AnyType(TypeOfAny.from_error)) + else: + new_all_types.append(typ) + return self.wrap_exception_group(new_all_types) return make_simplified_union(all_types) + def default_exception_type(self, is_star: bool) -> Type: + """Exception type to return in case of a previous type error.""" + any_type = AnyType(TypeOfAny.from_error) + if is_star: + return self.named_generic_type("builtins.ExceptionGroup", [any_type]) + return any_type + + def wrap_exception_group(self, types: Sequence[Type]) -> Type: + """Transform except* variable type into an appropriate exception group.""" + arg = make_simplified_union(types) + if is_subtype(arg, self.named_type("builtins.Exception")): + base = "builtins.ExceptionGroup" + else: + base = "builtins.BaseExceptionGroup" + return self.named_generic_type(base, [arg]) + def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]: """Helper for check_except_handler_test to retrieve handler types.""" typ = get_proper_type(typ) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 0d42ef53f456..209ebb89f36b 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1254,7 +1254,6 @@ def visit_Try(self, n: ast3.Try) -> TryStmt: return self.set_line(node, n) def visit_TryStar(self, n: TryStar) -> TryStmt: - # TODO: we treat TryStar exactly like Try, which makes mypy not crash. See #12840 vs = [ self.set_line(NameExpr(h.name), h) if h.name is not None else None for h in n.handlers ] @@ -1269,6 +1268,7 @@ def visit_TryStar(self, n: TryStar) -> TryStmt: self.as_block(n.orelse, n.lineno), self.as_block(n.finalbody, n.lineno), ) + node.is_star = True return self.set_line(node, n) # Assert(expr test, expr? msg) diff --git a/mypy/message_registry.py b/mypy/message_registry.py index c84ce120dbda..18acb2cd7a71 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -44,6 +44,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage: NO_RETURN_EXPECTED: Final = ErrorMessage("Return statement in function which does not return") INVALID_EXCEPTION: Final = ErrorMessage("Exception must be derived from BaseException") INVALID_EXCEPTION_TYPE: Final = ErrorMessage("Exception type must be derived from BaseException") +INVALID_EXCEPTION_GROUP: Final = ErrorMessage( + "Exception type in except* cannot derive from BaseExceptionGroup" +) RETURN_IN_ASYNC_GENERATOR: Final = ErrorMessage( '"return" with value in async generator is not allowed' ) diff --git a/mypy/nodes.py b/mypy/nodes.py index 9221ec48aa61..0ea89611dc1a 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1529,9 +1529,9 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class TryStmt(Statement): - __slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body") + __slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star") - __match_args__ = ("body", "types", "vars", "handlers", "else_body", "finally_body") + __match_args__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star") body: Block # Try body # Plain 'except:' also possible @@ -1540,6 +1540,8 @@ class TryStmt(Statement): handlers: list[Block] # Except bodies else_body: Block | None finally_body: Block | None + # Whether this is try ... except* (added in Python 3.11) + is_star: bool def __init__( self, @@ -1557,6 +1559,7 @@ def __init__( self.handlers = handlers self.else_body = else_body self.finally_body = finally_body + self.is_star = False def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_try_stmt(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 1acf7699316c..9b369618b88e 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -276,6 +276,8 @@ def visit_del_stmt(self, o: mypy.nodes.DelStmt) -> str: def visit_try_stmt(self, o: mypy.nodes.TryStmt) -> str: a: list[Any] = [o.body] + if o.is_star: + a.append("*") for i in range(len(o.vars)): a.append(o.types[i]) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index d7f159d02a22..c863db6b3dd5 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -373,7 +373,7 @@ def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt: return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr)) def visit_try_stmt(self, node: TryStmt) -> TryStmt: - return TryStmt( + new = TryStmt( self.block(node.body), self.optional_names(node.vars), self.optional_expressions(node.types), @@ -381,6 +381,8 @@ def visit_try_stmt(self, node: TryStmt) -> TryStmt: self.optional_block(node.else_body), self.optional_block(node.finally_body), ) + new.is_star = node.is_star + return new def visit_with_stmt(self, node: WithStmt) -> WithStmt: new = WithStmt( diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 371a305e67b9..a1d36c011aa1 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -616,6 +616,8 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None: # constructs that we compile separately. When we have a # try/except/else/finally, we treat the try/except/else as the # body of a try/finally block. + if t.is_star: + builder.error("Exception groups and except* cannot be compiled yet", t.line) if t.finally_body: def transform_try_body() -> None: diff --git a/test-data/unit/check-python311.test b/test-data/unit/check-python311.test index b98bccc9059d..9bf62b0c489d 100644 --- a/test-data/unit/check-python311.test +++ b/test-data/unit/check-python311.test @@ -1,6 +1,53 @@ -[case testTryStarDoesNotCrash] +[case testTryStarSimple] try: pass except* Exception as e: - reveal_type(e) # N: Revealed type is "builtins.Exception" + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]" +[builtins fixtures/exception.pyi] + +[case testTryStarMultiple] +try: + pass +except* Exception as e: + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]" +except* RuntimeError as e: + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.RuntimeError]" +[builtins fixtures/exception.pyi] + +[case testTryStarBase] +try: + pass +except* BaseException as e: + reveal_type(e) # N: Revealed type is "builtins.BaseExceptionGroup[builtins.BaseException]" +[builtins fixtures/exception.pyi] + +[case testTryStarTuple] +class Custom(Exception): ... + +try: + pass +except* (RuntimeError, Custom) as e: + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, __main__.Custom]]" +[builtins fixtures/exception.pyi] + +[case testTryStarInvalidType] +class Bad: ... +try: + pass +except* (RuntimeError, Bad) as e: # E: Exception type must be derived from BaseException + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]" +[builtins fixtures/exception.pyi] + +[case testTryStarGroupInvalid] +try: + pass +except* ExceptionGroup as e: # E: Exception type in except* cannot derive from BaseExceptionGroup + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]" +[builtins fixtures/exception.pyi] + +[case testTryStarGroupInvalidTuple] +try: + pass +except* (RuntimeError, ExceptionGroup) as e: # E: Exception type in except* cannot derive from BaseExceptionGroup + reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, Any]]" [builtins fixtures/exception.pyi] diff --git a/test-data/unit/fixtures/exception.pyi b/test-data/unit/fixtures/exception.pyi index bf6d21c8716e..1c88723e7191 100644 --- a/test-data/unit/fixtures/exception.pyi +++ b/test-data/unit/fixtures/exception.pyi @@ -1,3 +1,4 @@ +import sys from typing import Generic, TypeVar T = TypeVar('T') @@ -5,7 +6,8 @@ class object: def __init__(self): pass class type: pass -class tuple(Generic[T]): pass +class tuple(Generic[T]): + def __ge__(self, other: object) -> bool: ... class function: pass class int: pass class str: pass @@ -13,11 +15,14 @@ class unicode: pass class bool: pass class ellipsis: pass -# Note: this is a slight simplification. In Python 2, the inheritance hierarchy -# is actually Exception -> StandardError -> RuntimeError -> ... class BaseException: def __init__(self, *args: object) -> None: ... class Exception(BaseException): pass class RuntimeError(Exception): pass class NotImplementedError(RuntimeError): pass +if sys.version_info >= (3, 11): + _BT_co = TypeVar("_BT_co", bound=BaseException, covariant=True) + _T_co = TypeVar("_T_co", bound=Exception, covariant=True) + class BaseExceptionGroup(BaseException, Generic[_BT_co]): ... + class ExceptionGroup(BaseExceptionGroup[_T_co], Exception): ...