Skip to content

Commit

Permalink
more enum-related speedups (#12032)
Browse files Browse the repository at this point in the history
As a followup to #9394 address a few more O(n**2) behaviors
caused by decomposing enums into unions of literals.
  • Loading branch information
huguesb authored and JukkaL committed Apr 20, 2022
1 parent 914506b commit 963e4ae
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 16 deletions.
30 changes: 30 additions & 0 deletions mypy/meet.py
Expand Up @@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
if isinstance(declared, UnionType):
return make_simplified_union([narrow_declared_type(x, narrowed)
for x in declared.relevant_items()])
if is_enum_overlapping_union(declared, narrowed):
return narrowed
elif not is_overlapping_types(declared, narrowed,
prohibit_none_typevar_overlap=True):
if state.strict_optional:
Expand Down Expand Up @@ -137,6 +139,22 @@ def get_possible_variants(typ: Type) -> List[Type]:
return [typ]


def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
"""Return True if x is an Enum, and y is an Union with at least one Literal from x"""
return (
isinstance(x, Instance) and x.type.is_enum and
isinstance(y, UnionType) and
any(isinstance(p, LiteralType) and x.type == p.fallback.type
for p in (get_proper_type(z) for z in y.relevant_items()))
)


def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
"""Return True if x is a Literal and y is an Union that includes x"""
return (isinstance(x, LiteralType) and isinstance(y, UnionType) and
any(x == get_proper_type(z) for z in y.items))


def is_overlapping_types(left: Type,
right: Type,
ignore_promotions: bool = False,
Expand Down Expand Up @@ -198,6 +216,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
#
# These checks will also handle the NoneType and UninhabitedType cases for us.

# enums are sometimes expanded into an Union of Literals
# when that happens we want to make sure we treat the two as overlapping
# and crucially, we want to do that *fast* in case the enum is large
# so we do it before expanding variants below to avoid O(n**2) behavior
if (
is_enum_overlapping_union(left, right)
or is_enum_overlapping_union(right, left)
or is_literal_in_union(left, right)
or is_literal_in_union(right, left)
):
return True

if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
return True
Expand Down
34 changes: 28 additions & 6 deletions mypy/sametypes.py
@@ -1,12 +1,12 @@
from typing import Sequence
from typing import Sequence, Tuple, Set, List

from mypy.types import (
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType
)
from mypy.typeops import tuple_fallback, make_simplified_union
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal


def is_same_type(left: Type, right: Type) -> bool:
Expand Down Expand Up @@ -49,6 +49,22 @@ def is_same_types(a1: Sequence[Type], a2: Sequence[Type]) -> bool:
return True


def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]:
"""Given a UnionType, separate out its items into a set of simple literals and a remainder list
This is a useful helper to avoid O(n**2) behavior when comparing large unions, which can often
result from large enums in contexts where type narrowing removes a small subset of entries.
"""
lit: Set[Type] = set()
rem: List[Type] = []
for i in u.relevant_items():
i = get_proper_type(i)
if is_simple_literal(i):
lit.add(i)
else:
rem.append(i)
return lit, rem


class SameTypeVisitor(TypeVisitor[bool]):
"""Visitor for checking whether two types are the 'same' type."""

Expand Down Expand Up @@ -153,14 +169,20 @@ def visit_literal_type(self, left: LiteralType) -> bool:

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, UnionType):
left_lit, left_rem = _extract_literals(left)
right_lit, right_rem = _extract_literals(self.right)

if left_lit != right_lit:
return False

# Check that everything in left is in right
for left_item in left.items:
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
for left_item in left_rem:
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
return False

# Check that everything in right is in left
for right_item in self.right.items:
if not any(is_same_type(right_item, left_item) for left_item in left.items):
for right_item in right_rem:
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
return False

return True
Expand Down
59 changes: 49 additions & 10 deletions mypy/subtypes.py
Expand Up @@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, Instance):
literal_types: Set[Instance] = set()
# avoid redundant check for union of literals
for item in left.relevant_items():
item = get_proper_type(item)
lit_type = mypy.typeops.simple_literal_type(item)
if lit_type is not None:
if lit_type in literal_types:
continue
literal_types.add(lit_type)
item = lit_type
if not self._is_subtype(item, self.orig_right):
return False
return True
return all(self._is_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
Expand Down Expand Up @@ -1199,6 +1213,27 @@ def report(*args: Any) -> None:
return applied


def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
"""Return the items of t, excluding any occurrence of s, if and only if
- t only contains simple literals
- s is a simple literal
Otherwise, returns None
"""
ps = get_proper_type(s)
if not mypy.typeops.is_simple_literal(ps):
return None

new_items: List[Type] = []
for i in t.relevant_items():
pi = get_proper_type(i)
if not mypy.typeops.is_simple_literal(pi):
return None
if pi != ps:
new_items.append(i)
return new_items


def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
"""Return t minus s for runtime type assertions.
Expand All @@ -1212,10 +1247,14 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
s = get_proper_type(s)

if isinstance(t, UnionType):
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))]
new_items = try_restrict_literal_union(t, s)
if new_items is None:
new_items = [
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))
]
return UnionType.make_union(new_items)
elif covers_at_runtime(t, s, ignore_promotions):
return UninhabitedType()
Expand Down Expand Up @@ -1285,11 +1324,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
right = get_proper_type(right)

if isinstance(right, UnionType) and not isinstance(left, UnionType):
return any([is_proper_subtype(orig_left, item,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items])
return any(is_proper_subtype(orig_left, item,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items)
return left.accept(ProperSubtypeVisitor(orig_right,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
Expand Down Expand Up @@ -1495,7 +1534,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
# TODO: What's the right thing to do here?
Expand Down
9 changes: 9 additions & 0 deletions mypy/typeops.py
Expand Up @@ -318,6 +318,15 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
return None


def simple_literal_type(t: ProperType) -> Optional[Instance]:
"""Extract the underlying fallback Instance type for a simple Literal"""
if isinstance(t, Instance) and t.last_known_value is not None:
t = t.last_known_value
if isinstance(t, LiteralType):
return t.fallback
return None


def is_simple_literal(t: ProperType) -> bool:
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
if isinstance(t, LiteralType):
Expand Down

0 comments on commit 963e4ae

Please sign in to comment.