-
-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add union math for intelligent indexing #6558
Changes from all commits
e87a4ad
ac0f317
be7210f
2174a68
99e219e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
from collections import OrderedDict | ||
from contextlib import contextmanager | ||
import itertools | ||
from typing import ( | ||
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator | ||
) | ||
|
@@ -2554,15 +2555,18 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, | |
if isinstance(index, SliceExpr): | ||
return self.visit_tuple_slice_helper(left_type, index) | ||
|
||
n = self._get_value(index) | ||
if n is not None: | ||
if n < 0: | ||
n += len(left_type.items) | ||
if 0 <= n < len(left_type.items): | ||
return left_type.items[n] | ||
else: | ||
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e) | ||
return AnyType(TypeOfAny.from_error) | ||
ns = self.try_getting_int_literals(index) | ||
if ns is not None: | ||
out = [] | ||
for n in ns: | ||
if n < 0: | ||
n += len(left_type.items) | ||
if 0 <= n < len(left_type.items): | ||
out.append(left_type.items[n]) | ||
else: | ||
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e) | ||
return AnyType(TypeOfAny.from_error) | ||
return UnionType.make_simplified_union(out) | ||
else: | ||
return self.nonliteral_tuple_index_helper(left_type, index) | ||
elif isinstance(left_type, TypedDictType): | ||
|
@@ -2578,26 +2582,66 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, | |
return result | ||
|
||
def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type: | ||
begin = None | ||
end = None | ||
stride = None | ||
begin = [None] # type: Sequence[Optional[int]] | ||
end = [None] # type: Sequence[Optional[int]] | ||
stride = [None] # type: Sequence[Optional[int]] | ||
|
||
if slic.begin_index: | ||
begin = self._get_value(slic.begin_index) | ||
if begin is None: | ||
begin_raw = self.try_getting_int_literals(slic.begin_index) | ||
if begin_raw is None: | ||
return self.nonliteral_tuple_index_helper(left_type, slic) | ||
begin = begin_raw | ||
|
||
if slic.end_index: | ||
end = self._get_value(slic.end_index) | ||
if end is None: | ||
end_raw = self.try_getting_int_literals(slic.end_index) | ||
if end_raw is None: | ||
return self.nonliteral_tuple_index_helper(left_type, slic) | ||
end = end_raw | ||
|
||
if slic.stride: | ||
stride = self._get_value(slic.stride) | ||
if stride is None: | ||
stride_raw = self.try_getting_int_literals(slic.stride) | ||
if stride_raw is None: | ||
return self.nonliteral_tuple_index_helper(left_type, slic) | ||
stride = stride_raw | ||
|
||
items = [] # type: List[Type] | ||
for b, e, s in itertools.product(begin, end, stride): | ||
items.append(left_type.slice(b, e, s)) | ||
return UnionType.make_simplified_union(items) | ||
|
||
return left_type.slice(begin, stride, end) | ||
def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]: | ||
"""If the given expression or type corresponds to an int literal | ||
or a union of int literals, returns a list of the underlying ints. | ||
Otherwise, returns None. | ||
|
||
Specifically, this function is guaranteed to return a list with | ||
one or more ints if one one the following is true: | ||
|
||
1. 'expr' is a IntExpr or a UnaryExpr backed by an IntExpr | ||
2. 'typ' is a LiteralType containing an int | ||
3. 'typ' is a UnionType containing only LiteralType of ints | ||
""" | ||
if isinstance(index, IntExpr): | ||
return [index.value] | ||
elif isinstance(index, UnaryExpr): | ||
if index.op == '-': | ||
operand = index.expr | ||
if isinstance(operand, IntExpr): | ||
return [-1 * operand.value] | ||
typ = self.accept(index) | ||
if isinstance(typ, Instance) and typ.last_known_value is not None: | ||
typ = typ.last_known_value | ||
if isinstance(typ, LiteralType) and isinstance(typ.value, int): | ||
return [typ.value] | ||
if isinstance(typ, UnionType): | ||
out = [] | ||
for item in typ.items: | ||
if isinstance(item, LiteralType) and isinstance(item.value, int): | ||
out.append(item.value) | ||
else: | ||
return None | ||
return out | ||
return None | ||
|
||
def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type: | ||
index_type = self.accept(index) | ||
|
@@ -2614,40 +2658,36 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) | |
else: | ||
return union | ||
|
||
def _get_value(self, index: Expression) -> Optional[int]: | ||
if isinstance(index, IntExpr): | ||
return index.value | ||
elif isinstance(index, UnaryExpr): | ||
if index.op == '-': | ||
operand = index.expr | ||
if isinstance(operand, IntExpr): | ||
return -1 * operand.value | ||
typ = self.accept(index) | ||
if isinstance(typ, Instance) and typ.last_known_value is not None: | ||
typ = typ.last_known_value | ||
if isinstance(typ, LiteralType) and isinstance(typ.value, int): | ||
return typ.value | ||
return None | ||
|
||
def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: | ||
if isinstance(index, (StrExpr, UnicodeExpr)): | ||
item_name = index.value | ||
key_names = [index.value] | ||
else: | ||
typ = self.accept(index) | ||
if isinstance(typ, Instance) and typ.last_known_value is not None: | ||
typ = typ.last_known_value | ||
|
||
if isinstance(typ, LiteralType) and isinstance(typ.value, str): | ||
item_name = typ.value | ||
if isinstance(typ, UnionType): | ||
key_types = typ.items | ||
else: | ||
self.msg.typeddict_key_must_be_string_literal(td_type, index) | ||
return AnyType(TypeOfAny.from_error) | ||
key_types = [typ] | ||
|
||
item_type = td_type.items.get(item_name) | ||
if item_type is None: | ||
self.msg.typeddict_key_not_found(td_type, item_name, index) | ||
return AnyType(TypeOfAny.from_error) | ||
return item_type | ||
key_names = [] | ||
for key_type in key_types: | ||
if isinstance(key_type, Instance) and key_type.last_known_value is not None: | ||
key_type = key_type.last_known_value | ||
|
||
if isinstance(key_type, LiteralType) and isinstance(key_type.value, str): | ||
key_names.append(key_type.value) | ||
else: | ||
self.msg.typeddict_key_must_be_string_literal(td_type, index) | ||
return AnyType(TypeOfAny.from_error) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be practical to refactor this, so that this uses There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately, no -- this segment of code permits unicode keys but the IMO the correct thing to do is (a) modify But I think making this change fully correct is partly blocked on #6123, so I decided to defer the refactor for now. |
||
|
||
value_types = [] | ||
for key_name in key_names: | ||
value_type = td_type.items.get(key_name) | ||
if value_type is None: | ||
self.msg.typeddict_key_not_found(td_type, key_name, index) | ||
return AnyType(TypeOfAny.from_error) | ||
else: | ||
value_types.append(value_type) | ||
return UnionType.make_simplified_union(value_types) | ||
|
||
def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression, | ||
context: Context) -> Type: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of product would be clearer with a minimal example, like if there are two unions each with two elements (and stride is the same), the results would be a union with four elements.