Skip to content

Commit

Permalink
Properly check *CustomType and **CustomType arguments (python#11151)
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn authored and tushar-deepsource committed Jan 20, 2022
1 parent 54fe2b6 commit fbf722a
Show file tree
Hide file tree
Showing 15 changed files with 336 additions and 62 deletions.
45 changes: 31 additions & 14 deletions mypy/argmap.py
@@ -1,12 +1,16 @@
"""Utilities for mapping between actual and formal arguments (and their types)."""

from typing import List, Optional, Sequence, Callable, Set
from typing import TYPE_CHECKING, List, Optional, Sequence, Callable, Set

from mypy.maptype import map_instance_to_supertype
from mypy.types import (
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
)
from mypy import nodes

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext


def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
actual_names: Optional[Sequence[Optional[str]]],
Expand Down Expand Up @@ -140,11 +144,13 @@ def f(x: int, *args: str) -> None: ...
needs a separate instance since instances have per-call state.
"""

def __init__(self) -> None:
def __init__(self, context: 'ArgumentInferContext') -> None:
# Next tuple *args index to use.
self.tuple_index = 0
# Keyword arguments in TypedDict **kwargs used.
self.kwargs_used: Set[str] = set()
# Type context for `*` and `**` arg kinds.
self.context = context

def expand_actual_type(self,
actual_type: Type,
Expand All @@ -162,16 +168,21 @@ def expand_actual_type(self,
This is supposed to be called for each formal, in order. Call multiple times per
formal if multiple actuals map to a formal.
"""
from mypy.subtypes import is_subtype

actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, Instance):
if actual_type.type.fullname == 'builtins.list':
# List *arg.
return actual_type.args[0]
elif actual_type.args:
# TODO: Try to map type arguments to Iterable
return actual_type.args[0]
if isinstance(actual_type, Instance) and actual_type.args:
if is_subtype(actual_type, self.context.iterable_type):
return map_instance_to_supertype(
actual_type,
self.context.iterable_type.type,
).args[0]
else:
# We cannot properly unpack anything other
# than `Iterable` type with `*`.
# Just return `Any`, other parts of code would raise
# a different error for improper use.
return AnyType(TypeOfAny.from_error)
elif isinstance(actual_type, TupleType):
# Get the next tuple item of a tuple *arg.
Expand All @@ -193,11 +204,17 @@ def expand_actual_type(self,
formal_name = (set(actual_type.items.keys()) - self.kwargs_used).pop()
self.kwargs_used.add(formal_name)
return actual_type.items[formal_name]
elif (isinstance(actual_type, Instance)
and (actual_type.type.fullname == 'builtins.dict')):
# Dict **arg.
# TODO: Handle arbitrary Mapping
return actual_type.args[1]
elif (
isinstance(actual_type, Instance) and
len(actual_type.args) > 1 and
is_subtype(actual_type, self.context.mapping_type)
):
# Only `Mapping` type can be unpacked with `**`.
# Other types will produce an error somewhere else.
return map_instance_to_supertype(
actual_type,
self.context.mapping_type.type,
).args[1]
else:
return AnyType(TypeOfAny.from_error)
else:
Expand Down
17 changes: 14 additions & 3 deletions mypy/checkexpr.py
Expand Up @@ -45,7 +45,9 @@
from mypy.maptype import map_instance_to_supertype
from mypy.messages import MessageBuilder
from mypy import message_registry
from mypy.infer import infer_type_arguments, infer_function_type_arguments
from mypy.infer import (
ArgumentInferContext, infer_type_arguments, infer_function_type_arguments,
)
from mypy import join
from mypy.meet import narrow_declared_type, is_overlapping_types
from mypy.subtypes import is_subtype, is_proper_subtype, is_equivalent, non_method_protocol_members
Expand Down Expand Up @@ -1240,6 +1242,7 @@ def infer_function_type_arguments(self, callee_type: CallableType,

inferred_args = infer_function_type_arguments(
callee_type, pass1_args, arg_kinds, formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function())

if 2 in arg_pass_nums:
Expand Down Expand Up @@ -1301,10 +1304,18 @@ def infer_function_type_arguments_pass2(
callee_type, args, arg_kinds, formal_to_actual)

inferred_args = infer_function_type_arguments(
callee_type, arg_types, arg_kinds, formal_to_actual)
callee_type, arg_types, arg_kinds, formal_to_actual,
context=self.argument_infer_context(),
)

return callee_type, inferred_args

def argument_infer_context(self) -> ArgumentInferContext:
return ArgumentInferContext(
self.chk.named_type('typing.Mapping'),
self.chk.named_type('typing.Iterable'),
)

def get_arg_infer_passes(self, arg_types: List[Type],
formal_to_actual: List[List[int]],
num_actuals: int) -> List[int]:
Expand Down Expand Up @@ -1479,7 +1490,7 @@ def check_argument_types(self,
messages = messages or self.msg
check_arg = check_arg or self.check_arg
# Keep track of consumed tuple *arg items.
mapper = ArgTypeExpander()
mapper = ArgTypeExpander(self.argument_infer_context())
for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
actual_type = arg_types[actual]
Expand Down
14 changes: 10 additions & 4 deletions mypy/constraints.py
@@ -1,6 +1,6 @@
"""Type inference constraints."""

from typing import Iterable, List, Optional, Sequence
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Final

from mypy.types import (
Expand All @@ -18,6 +18,9 @@
from mypy.argmap import ArgTypeExpander
from mypy.typestate import TypeState

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext

SUBTYPE_OF: Final = 0
SUPERTYPE_OF: Final = 1

Expand Down Expand Up @@ -45,14 +48,17 @@ def __repr__(self) -> str:


def infer_constraints_for_callable(
callee: CallableType, arg_types: Sequence[Optional[Type]], arg_kinds: List[ArgKind],
formal_to_actual: List[List[int]]) -> List[Constraint]:
callee: CallableType,
arg_types: Sequence[Optional[Type]],
arg_kinds: List[ArgKind],
formal_to_actual: List[List[int]],
context: 'ArgumentInferContext') -> List[Constraint]:
"""Infer type variable constraints for a callable and actual arguments.
Return a list of constraints.
"""
constraints: List[Constraint] = []
mapper = ArgTypeExpander()
mapper = ArgTypeExpander(context)

for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
Expand Down
21 changes: 18 additions & 3 deletions mypy/infer.py
@@ -1,19 +1,34 @@
"""Utilities for type argument inference."""

from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, NamedTuple

from mypy.constraints import (
infer_constraints, infer_constraints_for_callable, SUBTYPE_OF, SUPERTYPE_OF
)
from mypy.types import Type, TypeVarId, CallableType
from mypy.types import Type, TypeVarId, CallableType, Instance
from mypy.nodes import ArgKind
from mypy.solve import solve_constraints


class ArgumentInferContext(NamedTuple):
"""Type argument inference context.
We need this because we pass around ``Mapping`` and ``Iterable`` types.
These types are only known by ``TypeChecker`` itself.
It is required for ``*`` and ``**`` argument inference.
https://github.com/python/mypy/issues/11144
"""

mapping_type: Instance
iterable_type: Instance


def infer_function_type_arguments(callee_type: CallableType,
arg_types: Sequence[Optional[Type]],
arg_kinds: List[ArgKind],
formal_to_actual: List[List[int]],
context: ArgumentInferContext,
strict: bool = True) -> List[Optional[Type]]:
"""Infer the type arguments of a generic function.
Expand All @@ -30,7 +45,7 @@ def infer_function_type_arguments(callee_type: CallableType,
"""
# Infer constraints.
constraints = infer_constraints_for_callable(
callee_type, arg_types, arg_kinds, formal_to_actual)
callee_type, arg_types, arg_kinds, formal_to_actual, context)

# Solve constraints.
type_vars = callee_type.type_var_ids()
Expand Down
35 changes: 4 additions & 31 deletions test-data/unit/check-expressions.test
Expand Up @@ -63,13 +63,7 @@ if str():
a = 1.1
class A:
pass
[file builtins.py]
class object:
def __init__(self): pass
class type: pass
class function: pass
class float: pass
class str: pass
[builtins fixtures/dict.pyi]

[case testComplexLiteral]
a = 0.0j
Expand All @@ -80,13 +74,7 @@ if str():
a = 1.1j
class A:
pass
[file builtins.py]
class object:
def __init__(self): pass
class type: pass
class function: pass
class complex: pass
class str: pass
[builtins fixtures/dict.pyi]

[case testBytesLiteral]
b, a = None, None # type: (bytes, A)
Expand All @@ -99,14 +87,7 @@ if str():
if str():
a = b'foo' # E: Incompatible types in assignment (expression has type "bytes", variable has type "A")
class A: pass
[file builtins.py]
class object:
def __init__(self): pass
class type: pass
class tuple: pass
class function: pass
class bytes: pass
class str: pass
[builtins fixtures/dict.pyi]

[case testUnicodeLiteralInPython3]
s = None # type: str
Expand Down Expand Up @@ -1535,15 +1516,7 @@ if str():
....a # E: "ellipsis" has no attribute "a"

class A: pass
[file builtins.py]
class object:
def __init__(self): pass
class ellipsis:
def __init__(self): pass
__class__ = object()
class type: pass
class function: pass
class str: pass
[builtins fixtures/dict.pyi]
[out]


Expand Down

0 comments on commit fbf722a

Please sign in to comment.