Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more enum-related speedups #12032

Merged
merged 2 commits into from Apr 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions mypy/meet.py
Expand Up @@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
if isinstance(declared, UnionType):
return make_simplified_union([narrow_declared_type(x, narrowed)
for x in declared.relevant_items()])
if is_enum_overlapping_union(declared, narrowed):
return narrowed
elif not is_overlapping_types(declared, narrowed,
prohibit_none_typevar_overlap=True):
if state.strict_optional:
Expand Down Expand Up @@ -137,6 +139,22 @@ def get_possible_variants(typ: Type) -> List[Type]:
return [typ]


def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
"""Return True if x is an Enum, and y is an Union with at least one Literal from x"""
return (
isinstance(x, Instance) and x.type.is_enum and
isinstance(y, UnionType) and
any(isinstance(p, LiteralType) and x.type == p.fallback.type
for p in (get_proper_type(z) for z in y.relevant_items()))
)


def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
"""Return True if x is a Literal and y is an Union that includes x"""
return (isinstance(x, LiteralType) and isinstance(y, UnionType) and
any(x == get_proper_type(z) for z in y.items))


def is_overlapping_types(left: Type,
right: Type,
ignore_promotions: bool = False,
Expand Down Expand Up @@ -198,6 +216,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
#
# These checks will also handle the NoneType and UninhabitedType cases for us.

# enums are sometimes expanded into an Union of Literals
# when that happens we want to make sure we treat the two as overlapping
# and crucially, we want to do that *fast* in case the enum is large
# so we do it before expanding variants below to avoid O(n**2) behavior
if (
is_enum_overlapping_union(left, right)
or is_enum_overlapping_union(right, left)
or is_literal_in_union(left, right)
or is_literal_in_union(right, left)
):
return True

if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
return True
Expand Down
34 changes: 28 additions & 6 deletions mypy/sametypes.py
@@ -1,12 +1,12 @@
from typing import Sequence
from typing import Sequence, Tuple, Set, List

from mypy.types import (
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType
)
from mypy.typeops import tuple_fallback, make_simplified_union
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal


def is_same_type(left: Type, right: Type) -> bool:
Expand Down Expand Up @@ -49,6 +49,22 @@ def is_same_types(a1: Sequence[Type], a2: Sequence[Type]) -> bool:
return True


def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]:
"""Given a UnionType, separate out its items into a set of simple literals and a remainder list
This is a useful helper to avoid O(n**2) behavior when comparing large unions, which can often
result from large enums in contexts where type narrowing removes a small subset of entries.
"""
lit: Set[Type] = set()
rem: List[Type] = []
for i in u.relevant_items():
i = get_proper_type(i)
if is_simple_literal(i):
lit.add(i)
else:
rem.append(i)
return lit, rem


class SameTypeVisitor(TypeVisitor[bool]):
"""Visitor for checking whether two types are the 'same' type."""

Expand Down Expand Up @@ -153,14 +169,20 @@ def visit_literal_type(self, left: LiteralType) -> bool:

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, UnionType):
left_lit, left_rem = _extract_literals(left)
right_lit, right_rem = _extract_literals(self.right)

if left_lit != right_lit:
return False

# Check that everything in left is in right
for left_item in left.items:
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
for left_item in left_rem:
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
return False

# Check that everything in right is in left
for right_item in self.right.items:
if not any(is_same_type(right_item, left_item) for left_item in left.items):
for right_item in right_rem:
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
return False

return True
Expand Down
59 changes: 49 additions & 10 deletions mypy/subtypes.py
Expand Up @@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, Instance):
literal_types: Set[Instance] = set()
# avoid redundant check for union of literals
for item in left.relevant_items():
item = get_proper_type(item)
lit_type = mypy.typeops.simple_literal_type(item)
if lit_type is not None:
if lit_type in literal_types:
continue
literal_types.add(lit_type)
item = lit_type
if not self._is_subtype(item, self.orig_right):
return False
return True
return all(self._is_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
Expand Down Expand Up @@ -1199,6 +1213,27 @@ def report(*args: Any) -> None:
return applied


def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a docstring to this function?

"""Return the items of t, excluding any occurrence of s, if and only if
- t only contains simple literals
- s is a simple literal

Otherwise, returns None
"""
ps = get_proper_type(s)
if not mypy.typeops.is_simple_literal(ps):
return None

new_items: List[Type] = []
for i in t.relevant_items():
pi = get_proper_type(i)
if not mypy.typeops.is_simple_literal(pi):
return None
if pi != ps:
new_items.append(i)
return new_items


def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
"""Return t minus s for runtime type assertions.

Expand All @@ -1212,10 +1247,14 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
s = get_proper_type(s)

if isinstance(t, UnionType):
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))]
new_items = try_restrict_literal_union(t, s)
if new_items is None:
new_items = [
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))
]
return UnionType.make_union(new_items)
elif covers_at_runtime(t, s, ignore_promotions):
return UninhabitedType()
Expand Down Expand Up @@ -1285,11 +1324,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
right = get_proper_type(right)

if isinstance(right, UnionType) and not isinstance(left, UnionType):
return any([is_proper_subtype(orig_left, item,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items])
return any(is_proper_subtype(orig_left, item,
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items)
return left.accept(ProperSubtypeVisitor(orig_right,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
Expand Down Expand Up @@ -1495,7 +1534,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
# TODO: What's the right thing to do here?
Expand Down
9 changes: 9 additions & 0 deletions mypy/typeops.py
Expand Up @@ -318,6 +318,15 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
return None


def simple_literal_type(t: ProperType) -> Optional[Instance]:
"""Extract the underlying fallback Instance type for a simple Literal"""
if isinstance(t, Instance) and t.last_known_value is not None:
t = t.last_known_value
if isinstance(t, LiteralType):
return t.fallback
return None


def is_simple_literal(t: ProperType) -> bool:
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
if isinstance(t, LiteralType):
Expand Down