Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add union math for intelligent indexing #6558

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
134 changes: 87 additions & 47 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import OrderedDict
from contextlib import contextmanager
import itertools
from typing import (
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator
)
Expand Down Expand Up @@ -2554,15 +2555,18 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
if isinstance(index, SliceExpr):
return self.visit_tuple_slice_helper(left_type, index)

n = self._get_value(index)
if n is not None:
if n < 0:
n += len(left_type.items)
if 0 <= n < len(left_type.items):
return left_type.items[n]
else:
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
return AnyType(TypeOfAny.from_error)
ns = self.try_getting_int_literals(index)
if ns is not None:
out = []
for n in ns:
if n < 0:
n += len(left_type.items)
if 0 <= n < len(left_type.items):
out.append(left_type.items[n])
else:
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
return AnyType(TypeOfAny.from_error)
return UnionType.make_simplified_union(out)
else:
return self.nonliteral_tuple_index_helper(left_type, index)
elif isinstance(left_type, TypedDictType):
Expand All @@ -2578,26 +2582,66 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
return result

def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type:
begin = None
end = None
stride = None
begin = [None] # type: Sequence[Optional[int]]
end = [None] # type: Sequence[Optional[int]]
stride = [None] # type: Sequence[Optional[int]]

if slic.begin_index:
begin = self._get_value(slic.begin_index)
if begin is None:
begin_raw = self.try_getting_int_literals(slic.begin_index)
if begin_raw is None:
return self.nonliteral_tuple_index_helper(left_type, slic)
begin = begin_raw

if slic.end_index:
end = self._get_value(slic.end_index)
if end is None:
end_raw = self.try_getting_int_literals(slic.end_index)
if end_raw is None:
return self.nonliteral_tuple_index_helper(left_type, slic)
end = end_raw

if slic.stride:
stride = self._get_value(slic.stride)
if stride is None:
stride_raw = self.try_getting_int_literals(slic.stride)
if stride_raw is None:
return self.nonliteral_tuple_index_helper(left_type, slic)
stride = stride_raw

items = [] # type: List[Type]
for b, e, s in itertools.product(begin, end, stride):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of product would be clearer with a minimal example: for instance, if there are two unions, each with two elements (and the stride is the same), the result would be a union with four elements.

items.append(left_type.slice(b, e, s))
return UnionType.make_simplified_union(items)

return left_type.slice(begin, stride, end)
def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]:
    """Extract the int value(s) that the given index expression is known to have.

    A list of ints is returned when 'index' is one of the following:

    1. an IntExpr, or a unary minus applied directly to an IntExpr;
    2. an expression whose inferred type is a LiteralType holding an int
       (or an Instance whose last_known_value is such a LiteralType);
    3. an expression whose inferred type is a UnionType whose items are
       all LiteralTypes holding ints.

    In every other case, None is returned.
    """
    # Plain (possibly negated) integer literals need no type inference.
    if isinstance(index, IntExpr):
        return [index.value]
    if (isinstance(index, UnaryExpr)
            and index.op == '-'
            and isinstance(index.expr, IntExpr)):
        return [-index.expr.value]

    inferred = self.accept(index)
    if isinstance(inferred, Instance) and inferred.last_known_value is not None:
        inferred = inferred.last_known_value

    # Treat a lone literal as a one-element union so both cases share a loop.
    if isinstance(inferred, UnionType):
        candidates = inferred.items
    elif isinstance(inferred, LiteralType):
        candidates = [inferred]
    else:
        return None

    values = []  # type: List[int]
    for candidate in candidates:
        if not (isinstance(candidate, LiteralType) and isinstance(candidate.value, int)):
            # A single non-int-literal member makes the whole index non-literal.
            return None
        values.append(candidate.value)
    return values

def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type:
index_type = self.accept(index)
Expand All @@ -2614,40 +2658,36 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression)
else:
return union

def _get_value(self, index: Expression) -> Optional[int]:
    """Return the single int that 'index' is known to equal, or None.

    The value is taken from an IntExpr, from a unary minus applied directly
    to an IntExpr, or from the inferred type of the expression when that type
    is a LiteralType holding an int (or an Instance whose last_known_value is
    such a LiteralType).
    """
    # Plain (possibly negated) integer literals need no type inference.
    if isinstance(index, IntExpr):
        return index.value
    if (isinstance(index, UnaryExpr)
            and index.op == '-'
            and isinstance(index.expr, IntExpr)):
        return -index.expr.value

    inferred = self.accept(index)
    if isinstance(inferred, Instance) and inferred.last_known_value is not None:
        inferred = inferred.last_known_value

    if not isinstance(inferred, LiteralType):
        return None
    literal_value = inferred.value
    return literal_value if isinstance(literal_value, int) else None

def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
if isinstance(index, (StrExpr, UnicodeExpr)):
item_name = index.value
key_names = [index.value]
else:
typ = self.accept(index)
if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value

if isinstance(typ, LiteralType) and isinstance(typ.value, str):
item_name = typ.value
if isinstance(typ, UnionType):
key_types = typ.items
else:
self.msg.typeddict_key_must_be_string_literal(td_type, index)
return AnyType(TypeOfAny.from_error)
key_types = [typ]

item_type = td_type.items.get(item_name)
if item_type is None:
self.msg.typeddict_key_not_found(td_type, item_name, index)
return AnyType(TypeOfAny.from_error)
return item_type
key_names = []
for key_type in key_types:
if isinstance(key_type, Instance) and key_type.last_known_value is not None:
key_type = key_type.last_known_value

if isinstance(key_type, LiteralType) and isinstance(key_type.value, str):
key_names.append(key_type.value)
else:
self.msg.typeddict_key_must_be_string_literal(td_type, index)
return AnyType(TypeOfAny.from_error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be practical to refactor this so that it uses try_getting_str_literals()? It looks like there is some code duplication. (If yes, the function should probably be moved to nodes.py and imported both here and in the plugin.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, no -- this segment of code permits unicode keys but the try_getting_str_literals section doesn't.

IMO the correct thing to do is (a) modify try_getting_str_literals so it also allows unicode keys and (b) convert this entire function into a plugin.

But I think making this change fully correct is partly blocked on #6123, so I decided to defer the refactor for now.


value_types = []
for key_name in key_names:
value_type = td_type.items.get(key_name)
if value_type is None:
self.msg.typeddict_key_not_found(td_type, key_name, index)
return AnyType(TypeOfAny.from_error)
else:
value_types.append(value_type)
return UnionType.make_simplified_union(value_types)

def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression,
context: Context) -> Type:
Expand Down
8 changes: 8 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,14 @@ def typeddict_key_cannot_be_deleted(
self.fail("Key '{}' of TypedDict {} cannot be deleted".format(
item_name, self.format(typ)), context)

def typeddict_setdefault_arguments_inconsistent(
        self,
        default: Type,
        expected: Type,
        context: Context) -> None:
    """Report an incompatible second argument to "setdefault" of a TypedDict.

    The 'default' and 'expected' types are rendered with self.format and the
    resulting message is passed to self.fail together with 'context'.
    """
    default_text = self.format(default)
    expected_text = self.format(expected)
    self.fail(
        'Argument 2 to "setdefault" of "TypedDict" has incompatible type {}; expected {}'
        .format(default_text, expected_text),
        context)

def type_arguments_not_allowed(self, context: Context) -> None:
    """Report, via self.fail, a parameterized generic used in a class or instance check."""
    message = 'Parameterized generics cannot be used with class or instance checks'
    self.fail(message, context)

Expand Down
44 changes: 30 additions & 14 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from mypy.plugin import ClassDefContext
from mypy.semanal import set_callable_name
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance
from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType, Instance, UnionType
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name

Expand Down Expand Up @@ -129,18 +129,34 @@ def add_method(
info.defn.defs.body.append(func)


def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
"""If this expression is a string literal, or if the corresponding type
is something like 'Literal["some string here"]', returns the underlying
string value. Otherwise, returns None."""
def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]:
"""If the given expression or type corresponds to a string literal
or a union of string literals, returns a list of the underlying strings.
Otherwise, returns None.

Specifically, this function is guaranteed to return a list with
one or more strings if one of the following is true:

1. 'expr' is a StrExpr
2. 'typ' is a LiteralType containing a string
3. 'typ' is a UnionType containing only LiteralType of strings
"""
if isinstance(expr, StrExpr):
return [expr.value]

if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value

if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
val = typ.value
assert isinstance(val, str)
return val
elif isinstance(expr, StrExpr):
return expr.value
possible_literals = [typ.last_known_value] # type: List[Type]
elif isinstance(typ, UnionType):
possible_literals = typ.items
else:
return None
possible_literals = [typ]

strings = []
for lit in possible_literals:
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
val = lit.value
assert isinstance(val, str)
strings.append(val)
else:
return None
return strings