New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
more enum-related speedups #12032
more enum-related speedups #12032
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
from typing import Sequence | ||
from typing import Sequence, Tuple, Set, List | ||
|
||
from mypy.types import ( | ||
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, | ||
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, | ||
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, | ||
ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType | ||
) | ||
from mypy.typeops import tuple_fallback, make_simplified_union | ||
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal | ||
|
||
|
||
def is_same_type(left: Type, right: Type) -> bool: | ||
|
@@ -153,14 +153,32 @@ def visit_literal_type(self, left: LiteralType) -> bool: | |
|
||
def visit_union_type(self, left: UnionType) -> bool: | ||
if isinstance(self.right, UnionType): | ||
# fast path for simple literals | ||
def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe nested functions don't perform well with mypyc, can you make this a global function instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
lit: Set[Type] = set() | ||
rem: List[Type] = [] | ||
for i in u.relevant_items(): | ||
i = get_proper_type(i) | ||
if is_simple_literal(i): | ||
lit.add(i) | ||
else: | ||
rem.append(i) | ||
return lit, rem | ||
|
||
left_lit, left_rem = _extract_literals(left) | ||
right_lit, right_rem = _extract_literals(self.right) | ||
|
||
if left_lit != right_lit: | ||
return False | ||
|
||
# Check that everything in left is in right | ||
for left_item in left.items: | ||
if not any(is_same_type(left_item, right_item) for right_item in self.right.items): | ||
for left_item in left_rem: | ||
if not any(is_same_type(left_item, right_item) for right_item in right_rem): | ||
return False | ||
|
||
# Check that everything in right is in left | ||
for right_item in self.right.items: | ||
if not any(is_same_type(right_item, left_item) for left_item in left.items): | ||
for right_item in right_rem: | ||
if not any(is_same_type(right_item, left_item) for left_item in left_rem): | ||
return False | ||
|
||
return True | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool: | |
return False | ||
|
||
def visit_union_type(self, left: UnionType) -> bool: | ||
if isinstance(self.right, Instance): | ||
literal_types: Set[Instance] = set() | ||
# avoid redundant check for union of literals | ||
for item in left.relevant_items(): | ||
item = get_proper_type(item) | ||
lit_type = mypy.typeops.simple_literal_type(item) | ||
if lit_type is not None: | ||
if lit_type in literal_types: | ||
continue | ||
literal_types.add(lit_type) | ||
item = lit_type | ||
if not self._is_subtype(item, self.orig_right): | ||
return False | ||
return True | ||
return all(self._is_subtype(item, self.orig_right) for item in left.items) | ||
|
||
def visit_partial_type(self, left: PartialType) -> bool: | ||
|
@@ -1199,6 +1213,18 @@ def report(*args: Any) -> None: | |
return applied | ||
|
||
|
||
def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a docstring to this function? |
||
"""Helper function for restrict_subtype_away, allowing a fast path for Union of simple literals""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This docstring doesn't say what the function does, which is the part I was having trouble with. What are the parameters and what does it return? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
new_items: List[Type] = [] | ||
for i in t.relevant_items(): | ||
it = get_proper_type(i) | ||
if not mypy.typeops.is_simple_literal(it): | ||
return None | ||
if it != s: | ||
new_items.append(i) | ||
return new_items | ||
|
||
|
||
def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type: | ||
"""Return t minus s for runtime type assertions. | ||
|
||
|
@@ -1212,10 +1238,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) | |
s = get_proper_type(s) | ||
|
||
if isinstance(t, UnionType): | ||
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions) | ||
for item in t.relevant_items() | ||
if (isinstance(get_proper_type(item), AnyType) or | ||
not covers_at_runtime(item, s, ignore_promotions))] | ||
new_items = try_restrict_literal_union(t, s) if isinstance(s, LiteralType) else [] | ||
new_items = new_items or [ | ||
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions) | ||
for item in t.relevant_items() | ||
if (isinstance(get_proper_type(item), AnyType) or | ||
not covers_at_runtime(item, s, ignore_promotions)) | ||
] | ||
return UnionType.make_union(new_items) | ||
elif covers_at_runtime(t, s, ignore_promotions): | ||
return UninhabitedType() | ||
|
@@ -1285,11 +1314,11 @@ def _is_proper_subtype(left: Type, right: Type, *, | |
right = get_proper_type(right) | ||
|
||
if isinstance(right, UnionType) and not isinstance(left, UnionType): | ||
return any([is_proper_subtype(orig_left, item, | ||
ignore_promotions=ignore_promotions, | ||
erase_instances=erase_instances, | ||
keep_erased_types=keep_erased_types) | ||
for item in right.items]) | ||
return any(is_proper_subtype(orig_left, item, | ||
JelleZijlstra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ignore_promotions=ignore_promotions, | ||
erase_instances=erase_instances, | ||
keep_erased_types=keep_erased_types) | ||
for item in right.items) | ||
return left.accept(ProperSubtypeVisitor(orig_right, | ||
ignore_promotions=ignore_promotions, | ||
erase_instances=erase_instances, | ||
|
@@ -1495,7 +1524,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: | |
return False | ||
|
||
def visit_union_type(self, left: UnionType) -> bool: | ||
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items]) | ||
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items) | ||
|
||
def visit_partial_type(self, left: PartialType) -> bool: | ||
# TODO: What's the right thing to do here? | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be this instead? As I understand it, this function checks that
x
is an enum andy
is a union containing only members ofx
.It would be helpful to add a docstring to this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Upon looking at this again, I think the previous implementation was wrong since we only care about partial overlap. Fixed and documented accordingly.