Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now *CustomType and **CustomType arguments are properly checked #11151

Merged
merged 8 commits into from Oct 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 31 additions & 14 deletions mypy/argmap.py
@@ -1,12 +1,16 @@
"""Utilities for mapping between actual and formal arguments (and their types)."""

from typing import List, Optional, Sequence, Callable, Set
from typing import TYPE_CHECKING, List, Optional, Sequence, Callable, Set

from mypy.maptype import map_instance_to_supertype
from mypy.types import (
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
)
from mypy import nodes

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext


def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
actual_names: Optional[Sequence[Optional[str]]],
Expand Down Expand Up @@ -140,11 +144,13 @@ def f(x: int, *args: str) -> None: ...
needs a separate instance since instances have per-call state.
"""

def __init__(self) -> None:
def __init__(self, context: 'ArgumentInferContext') -> None:
# Next tuple *args index to use.
self.tuple_index = 0
# Keyword arguments in TypedDict **kwargs used.
self.kwargs_used: Set[str] = set()
# Type context for `*` and `**` arg kinds.
self.context = context

def expand_actual_type(self,
actual_type: Type,
Expand All @@ -162,16 +168,21 @@ def expand_actual_type(self,
This is supposed to be called for each formal, in order. Call multiple times per
formal if multiple actuals map to a formal.
"""
from mypy.subtypes import is_subtype

actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, Instance):
if actual_type.type.fullname == 'builtins.list':
# List *arg.
return actual_type.args[0]
elif actual_type.args:
# TODO: Try to map type arguments to Iterable
return actual_type.args[0]
if isinstance(actual_type, Instance) and actual_type.args:
if is_subtype(actual_type, self.context.iterable_type):
return map_instance_to_supertype(
actual_type,
self.context.iterable_type.type,
).args[0]
else:
# We cannot properly unpack anything other
# than `Iterable` type with `*`.
# Just return `Any`, other parts of code would raise
# a different error for improper use.
return AnyType(TypeOfAny.from_error)
elif isinstance(actual_type, TupleType):
# Get the next tuple item of a tuple *arg.
Expand All @@ -193,11 +204,17 @@ def expand_actual_type(self,
formal_name = (set(actual_type.items.keys()) - self.kwargs_used).pop()
self.kwargs_used.add(formal_name)
return actual_type.items[formal_name]
elif (isinstance(actual_type, Instance)
and (actual_type.type.fullname == 'builtins.dict')):
# Dict **arg.
# TODO: Handle arbitrary Mapping
return actual_type.args[1]
elif (
isinstance(actual_type, Instance) and
len(actual_type.args) > 1 and
is_subtype(actual_type, self.context.mapping_type)
):
# Only `Mapping` type can be unpacked with `**`.
# Other types will produce an error somewhere else.
return map_instance_to_supertype(
actual_type,
self.context.mapping_type.type,
).args[1]
else:
return AnyType(TypeOfAny.from_error)
else:
Expand Down
17 changes: 14 additions & 3 deletions mypy/checkexpr.py
Expand Up @@ -45,7 +45,9 @@
from mypy.maptype import map_instance_to_supertype
from mypy.messages import MessageBuilder
from mypy import message_registry
from mypy.infer import infer_type_arguments, infer_function_type_arguments
from mypy.infer import (
ArgumentInferContext, infer_type_arguments, infer_function_type_arguments,
)
from mypy import join
from mypy.meet import narrow_declared_type, is_overlapping_types
from mypy.subtypes import is_subtype, is_proper_subtype, is_equivalent, non_method_protocol_members
Expand Down Expand Up @@ -1240,6 +1242,7 @@ def infer_function_type_arguments(self, callee_type: CallableType,

inferred_args = infer_function_type_arguments(
callee_type, pass1_args, arg_kinds, formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function())

if 2 in arg_pass_nums:
Expand Down Expand Up @@ -1301,10 +1304,18 @@ def infer_function_type_arguments_pass2(
callee_type, args, arg_kinds, formal_to_actual)

inferred_args = infer_function_type_arguments(
callee_type, arg_types, arg_kinds, formal_to_actual)
callee_type, arg_types, arg_kinds, formal_to_actual,
context=self.argument_infer_context(),
)

return callee_type, inferred_args

def argument_infer_context(self) -> ArgumentInferContext:
return ArgumentInferContext(
self.chk.named_type('typing.Mapping'),
self.chk.named_type('typing.Iterable'),
)

def get_arg_infer_passes(self, arg_types: List[Type],
formal_to_actual: List[List[int]],
num_actuals: int) -> List[int]:
Expand Down Expand Up @@ -1479,7 +1490,7 @@ def check_argument_types(self,
messages = messages or self.msg
check_arg = check_arg or self.check_arg
# Keep track of consumed tuple *arg items.
mapper = ArgTypeExpander()
mapper = ArgTypeExpander(self.argument_infer_context())
for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
actual_type = arg_types[actual]
Expand Down
14 changes: 10 additions & 4 deletions mypy/constraints.py
@@ -1,6 +1,6 @@
"""Type inference constraints."""

from typing import Iterable, List, Optional, Sequence
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Final

from mypy.types import (
Expand All @@ -18,6 +18,9 @@
from mypy.argmap import ArgTypeExpander
from mypy.typestate import TypeState

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext

SUBTYPE_OF: Final = 0
SUPERTYPE_OF: Final = 1

Expand Down Expand Up @@ -45,14 +48,17 @@ def __repr__(self) -> str:


def infer_constraints_for_callable(
callee: CallableType, arg_types: Sequence[Optional[Type]], arg_kinds: List[ArgKind],
formal_to_actual: List[List[int]]) -> List[Constraint]:
callee: CallableType,
arg_types: Sequence[Optional[Type]],
arg_kinds: List[ArgKind],
formal_to_actual: List[List[int]],
context: 'ArgumentInferContext') -> List[Constraint]:
"""Infer type variable constraints for a callable and actual arguments.

Return a list of constraints.
"""
constraints: List[Constraint] = []
mapper = ArgTypeExpander()
mapper = ArgTypeExpander(context)

for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
Expand Down
21 changes: 18 additions & 3 deletions mypy/infer.py
@@ -1,19 +1,34 @@
"""Utilities for type argument inference."""

from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, NamedTuple

from mypy.constraints import (
infer_constraints, infer_constraints_for_callable, SUBTYPE_OF, SUPERTYPE_OF
)
from mypy.types import Type, TypeVarId, CallableType
from mypy.types import Type, TypeVarId, CallableType, Instance
from mypy.nodes import ArgKind
from mypy.solve import solve_constraints


class ArgumentInferContext(NamedTuple):
"""Type argument inference context.

We need this because we pass around ``Mapping`` and ``Iterable`` types.
These types are only known by ``TypeChecker`` itself.
It is required for ``*`` and ``**`` argument inference.

https://github.com/python/mypy/issues/11144
"""

mapping_type: Instance
iterable_type: Instance


def infer_function_type_arguments(callee_type: CallableType,
arg_types: Sequence[Optional[Type]],
arg_kinds: List[ArgKind],
formal_to_actual: List[List[int]],
context: ArgumentInferContext,
strict: bool = True) -> List[Optional[Type]]:
"""Infer the type arguments of a generic function.

Expand All @@ -30,7 +45,7 @@ def infer_function_type_arguments(callee_type: CallableType,
"""
# Infer constraints.
constraints = infer_constraints_for_callable(
callee_type, arg_types, arg_kinds, formal_to_actual)
callee_type, arg_types, arg_kinds, formal_to_actual, context)

# Solve constraints.
type_vars = callee_type.type_var_ids()
Expand Down
35 changes: 4 additions & 31 deletions test-data/unit/check-expressions.test
Expand Up @@ -63,13 +63,7 @@ if str():
a = 1.1
class A:
pass
[file builtins.py]
class object:
def __init__(self): pass
class type: pass
class function: pass
class float: pass
class str: pass
[builtins fixtures/dict.pyi]

[case testComplexLiteral]
a = 0.0j
Expand All @@ -80,13 +74,7 @@ if str():
a = 1.1j
class A:
pass
[file builtins.py]
class object:
def __init__(self): pass
class type: pass
class function: pass
class complex: pass
class str: pass
[builtins fixtures/dict.pyi]

[case testBytesLiteral]
b, a = None, None # type: (bytes, A)
Expand All @@ -99,14 +87,7 @@ if str():
if str():
a = b'foo' # E: Incompatible types in assignment (expression has type "bytes", variable has type "A")
class A: pass
[file builtins.py]
class object:
def __init__(self): pass
class type: pass
class tuple: pass
class function: pass
class bytes: pass
class str: pass
[builtins fixtures/dict.pyi]

[case testUnicodeLiteralInPython3]
s = None # type: str
Expand Down Expand Up @@ -1535,15 +1516,7 @@ if str():
....a # E: "ellipsis" has no attribute "a"

class A: pass
[file builtins.py]
class object:
def __init__(self): pass
class ellipsis:
def __init__(self): pass
__class__ = object()
class type: pass
class function: pass
class str: pass
[builtins fixtures/dict.pyi]
[out]


Expand Down