From 963e4aedbe823ad754847fcac122ce678270f72f Mon Sep 17 00:00:00 2001 From: Hugues Date: Mon, 18 Apr 2022 07:54:02 -0700 Subject: [PATCH] more enum-related speedups (#12032) As a followup to #9394 address a few more O(n**2) behaviors caused by decomposing enums into unions of literals. --- mypy/meet.py | 30 ++++++++++++++++++++++++ mypy/sametypes.py | 34 ++++++++++++++++++++++----- mypy/subtypes.py | 59 +++++++++++++++++++++++++++++++++++++++-------- mypy/typeops.py | 9 ++++++++ 4 files changed, 116 insertions(+), 16 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index 535e48bc3a23..5ee64416490d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: if isinstance(declared, UnionType): return make_simplified_union([narrow_declared_type(x, narrowed) for x in declared.relevant_items()]) + if is_enum_overlapping_union(declared, narrowed): + return narrowed elif not is_overlapping_types(declared, narrowed, prohibit_none_typevar_overlap=True): if state.strict_optional: @@ -137,6 +139,22 @@ def get_possible_variants(typ: Type) -> List[Type]: return [typ] +def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool: + """Return True if x is an Enum, and y is an Union with at least one Literal from x""" + return ( + isinstance(x, Instance) and x.type.is_enum and + isinstance(y, UnionType) and + any(isinstance(p, LiteralType) and x.type == p.fallback.type + for p in (get_proper_type(z) for z in y.relevant_items())) + ) + + +def is_literal_in_union(x: ProperType, y: ProperType) -> bool: + """Return True if x is a Literal and y is an Union that includes x""" + return (isinstance(x, LiteralType) and isinstance(y, UnionType) and + any(x == get_proper_type(z) for z in y.items)) + + def is_overlapping_types(left: Type, right: Type, ignore_promotions: bool = False, @@ -198,6 +216,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # # These checks will also handle the NoneType and UninhabitedType cases for us. + # enums are sometimes expanded into an Union of Literals + # when that happens we want to make sure we treat the two as overlapping + # and crucially, we want to do that *fast* in case the enum is large + # so we do it before expanding variants below to avoid O(n**2) behavior + if ( + is_enum_overlapping_union(left, right) + or is_enum_overlapping_union(right, left) + or is_literal_in_union(left, right) + or is_literal_in_union(right, left) + ): + return True + if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions) or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)): return True diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 46798018410d..1c22c32f8b06 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Sequence, Tuple, Set, List from mypy.types import ( Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, @@ -6,7 +6,7 @@ Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType ) -from mypy.typeops import tuple_fallback, make_simplified_union +from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal def is_same_type(left: Type, right: Type) -> bool: @@ -49,6 +49,22 @@ def is_same_types(a1: Sequence[Type], a2: Sequence[Type]) -> bool: return True +def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]: + """Given a UnionType, separate out its items into a set of simple literals and a remainder list + This is a useful helper to avoid O(n**2) behavior when comparing large unions, which can often + result from large enums in contexts where type narrowing removes a small subset of entries. + """ + lit: Set[Type] = set() + rem: List[Type] = [] + for i in u.relevant_items(): + i = get_proper_type(i) + if is_simple_literal(i): + lit.add(i) + else: + rem.append(i) + return lit, rem + + class SameTypeVisitor(TypeVisitor[bool]): """Visitor for checking whether two types are the 'same' type.""" @@ -153,14 +169,20 @@ def visit_literal_type(self, left: LiteralType) -> bool: def visit_union_type(self, left: UnionType) -> bool: if isinstance(self.right, UnionType): + left_lit, left_rem = _extract_literals(left) + right_lit, right_rem = _extract_literals(self.right) + + if left_lit != right_lit: + return False + # Check that everything in left is in right - for left_item in left.items: - if not any(is_same_type(left_item, right_item) for right_item in self.right.items): + for left_item in left_rem: + if not any(is_same_type(left_item, right_item) for right_item in right_rem): return False # Check that everything in right is in left - for right_item in self.right.items: - if not any(is_same_type(right_item, left_item) for left_item in left.items): + for right_item in right_rem: + if not any(is_same_type(right_item, left_item) for left_item in left_rem): return False return True diff --git a/mypy/subtypes.py b/mypy/subtypes.py index af35821d7ef4..d977a114bf2f 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool: return False def visit_union_type(self, left: UnionType) -> bool: + if isinstance(self.right, Instance): + literal_types: Set[Instance] = set() + # avoid redundant check for union of literals + for item in left.relevant_items(): + item = get_proper_type(item) + lit_type = mypy.typeops.simple_literal_type(item) + if lit_type is not None: + if lit_type in literal_types: + continue + literal_types.add(lit_type) + item = lit_type + if not self._is_subtype(item, self.orig_right): + return False + return True return all(self._is_subtype(item, self.orig_right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: @@ -1199,6 +1213,27 @@ def report(*args: Any) -> None: return applied +def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]: + """Return the items of t, excluding any occurrence of s, if and only if + - t only contains simple literals + - s is a simple literal + + Otherwise, returns None + """ + ps = get_proper_type(s) + if not mypy.typeops.is_simple_literal(ps): + return None + + new_items: List[Type] = [] + for i in t.relevant_items(): + pi = get_proper_type(i) + if not mypy.typeops.is_simple_literal(pi): + return None + if pi != ps: + new_items.append(i) + return new_items + + def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type: """Return t minus s for runtime type assertions. @@ -1212,10 +1247,14 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) s = get_proper_type(s) if isinstance(t, UnionType): - new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions) - for item in t.relevant_items() - if (isinstance(get_proper_type(item), AnyType) or - not covers_at_runtime(item, s, ignore_promotions))] + new_items = try_restrict_literal_union(t, s) + if new_items is None: + new_items = [ + restrict_subtype_away(item, s, ignore_promotions=ignore_promotions) + for item in t.relevant_items() + if (isinstance(get_proper_type(item), AnyType) or + not covers_at_runtime(item, s, ignore_promotions)) + ] return UnionType.make_union(new_items) elif covers_at_runtime(t, s, ignore_promotions): return UninhabitedType() @@ -1285,11 +1324,11 @@ def _is_proper_subtype(left: Type, right: Type, *, right = get_proper_type(right) if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any([is_proper_subtype(orig_left, item, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types) - for item in right.items]) + return any(is_proper_subtype(orig_left, item, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances, + keep_erased_types=keep_erased_types) + for item in right.items) return left.accept(ProperSubtypeVisitor(orig_right, ignore_promotions=ignore_promotions, erase_instances=erase_instances, @@ -1495,7 +1534,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: return False def visit_union_type(self, left: UnionType) -> bool: - return all([self._is_proper_subtype(item, self.orig_right) for item in left.items]) + return all(self._is_proper_subtype(item, self.orig_right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: # TODO: What's the right thing to do here? diff --git a/mypy/typeops.py b/mypy/typeops.py index bb5d8c291b12..9eb2c9bee18f 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -318,6 +318,15 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]: return None +def simple_literal_type(t: ProperType) -> Optional[Instance]: + """Extract the underlying fallback Instance type for a simple Literal""" + if isinstance(t, Instance) and t.last_known_value is not None: + t = t.last_known_value + if isinstance(t, LiteralType): + return t.fallback + return None + + def is_simple_literal(t: ProperType) -> bool: """Fast way to check if simple_literal_value_key() would return a non-None value.""" if isinstance(t, LiteralType):