Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more enum-related speedups #12032

Merged
merged 2 commits into from Apr 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 29 additions & 0 deletions mypy/meet.py
Expand Up @@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
if isinstance(declared, UnionType):
return make_simplified_union([narrow_declared_type(x, narrowed)
for x in declared.relevant_items()])
if is_enum_overlapping_union(declared, narrowed):
return narrowed
elif not is_overlapping_types(declared, narrowed,
prohibit_none_typevar_overlap=True):
if state.strict_optional:
Expand Down Expand Up @@ -137,6 +139,21 @@ def get_possible_variants(typ: Type) -> List[Type]:
return [typ]


def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
return (
isinstance(x, Instance) and x.type.is_enum and
isinstance(y, UnionType) and
all(x.type == p.fallback.type
for p in (get_proper_type(z) for z in y.relevant_items())
if isinstance(p, LiteralType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
all(x.type == p.fallback.type
for p in (get_proper_type(z) for z in y.relevant_items())
if isinstance(p, LiteralType))
all(isinstance(p, LiteralType)) and x.type == p.fallback.type
for p in (get_proper_type(z) for z in y.relevant_items()))

Should it be this instead? As I understand it, this function checks that x is an enum and y is a union containing only members of x.

It would be helpful to add a docstring to this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Upon looking at this again, I think the previous implementation was wrong since we only care about partial overlap. Fixed and documented accordingly.

)


def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
return (isinstance(x, LiteralType) and isinstance(y, UnionType) and
any(x == get_proper_type(z) for z in y.items))


def is_overlapping_types(left: Type,
right: Type,
ignore_promotions: bool = False,
Expand Down Expand Up @@ -198,6 +215,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
#
# These checks will also handle the NoneType and UninhabitedType cases for us.

# enums are sometimes expanded into an Union of Literals
# when that happens we want to make sure we treat the two as overlapping
# and crucially, we want to do that *fast* in case the enum is large
# so we do it before expanding variants below to avoid O(n**2) behavior
if (
is_enum_overlapping_union(left, right)
or is_enum_overlapping_union(right, left)
or is_literal_in_union(left, right)
or is_literal_in_union(right, left)
):
return True

if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
return True
Expand Down
30 changes: 24 additions & 6 deletions mypy/sametypes.py
@@ -1,12 +1,12 @@
from typing import Sequence
from typing import Sequence, Tuple, Set, List

from mypy.types import (
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType
)
from mypy.typeops import tuple_fallback, make_simplified_union
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal


def is_same_type(left: Type, right: Type) -> bool:
Expand Down Expand Up @@ -153,14 +153,32 @@ def visit_literal_type(self, left: LiteralType) -> bool:

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, UnionType):
# fast path for simple literals
def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe nested functions don't perform well with mypyc, can you make this a global function instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

lit: Set[Type] = set()
rem: List[Type] = []
for i in u.relevant_items():
i = get_proper_type(i)
if is_simple_literal(i):
lit.add(i)
else:
rem.append(i)
return lit, rem

left_lit, left_rem = _extract_literals(left)
right_lit, right_rem = _extract_literals(self.right)

if left_lit != right_lit:
return False

# Check that everything in left is in right
for left_item in left.items:
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
for left_item in left_rem:
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
return False

# Check that everything in right is in left
for right_item in self.right.items:
if not any(is_same_type(right_item, left_item) for left_item in left.items):
for right_item in right_rem:
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
return False

return True
Expand Down
49 changes: 39 additions & 10 deletions mypy/subtypes.py
Expand Up @@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
if isinstance(self.right, Instance):
literal_types: Set[Instance] = set()
# avoid redundant check for union of literals
for item in left.relevant_items():
item = get_proper_type(item)
lit_type = mypy.typeops.simple_literal_type(item)
if lit_type is not None:
if lit_type in literal_types:
continue
literal_types.add(lit_type)
item = lit_type
if not self._is_subtype(item, self.orig_right):
return False
return True
return all(self._is_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
Expand Down Expand Up @@ -1199,6 +1213,18 @@ def report(*args: Any) -> None:
return applied


def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a docstring to this function?

"""Helper function for restrict_subtype_away, allowing a fast path for Union of simple literals"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring doesn't say what the function does, which is the part I was having trouble with. What are the parameters and what does it return?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

new_items: List[Type] = []
for i in t.relevant_items():
it = get_proper_type(i)
if not mypy.typeops.is_simple_literal(it):
return None
if it != s:
new_items.append(i)
return new_items


def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
"""Return t minus s for runtime type assertions.

Expand All @@ -1212,10 +1238,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
s = get_proper_type(s)

if isinstance(t, UnionType):
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))]
new_items = try_restrict_literal_union(t, s) if isinstance(s, LiteralType) else []
new_items = new_items or [
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
for item in t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or
not covers_at_runtime(item, s, ignore_promotions))
]
return UnionType.make_union(new_items)
elif covers_at_runtime(t, s, ignore_promotions):
return UninhabitedType()
Expand Down Expand Up @@ -1285,11 +1314,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
right = get_proper_type(right)

if isinstance(right, UnionType) and not isinstance(left, UnionType):
return any([is_proper_subtype(orig_left, item,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items])
return any(is_proper_subtype(orig_left, item,
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types)
for item in right.items)
return left.accept(ProperSubtypeVisitor(orig_right,
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
Expand Down Expand Up @@ -1495,7 +1524,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
return False

def visit_union_type(self, left: UnionType) -> bool:
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)

def visit_partial_type(self, left: PartialType) -> bool:
# TODO: What's the right thing to do here?
Expand Down
9 changes: 9 additions & 0 deletions mypy/typeops.py
Expand Up @@ -318,6 +318,15 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
return None


def simple_literal_type(t: ProperType) -> Optional[Instance]:
"""Extract the underlying fallback Instance type for a simple Literal"""
if isinstance(t, Instance) and t.last_known_value is not None:
t = t.last_known_value
if isinstance(t, LiteralType):
return t.fallback
return None


def is_simple_literal(t: ProperType) -> bool:
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
if isinstance(t, LiteralType):
Expand Down