Skip to content

Commit

Permalink
Now TypeInfo.get_method also returns Decorator nodes (#11150)
Browse files Browse the repository at this point in the history
Support decorators properly in additional contexts.

Closes #10409
  • Loading branch information
sobolevn committed Dec 3, 2021
1 parent 872bc86 commit 4e34fec
Show file tree
Hide file tree
Showing 10 changed files with 505 additions and 66 deletions.
30 changes: 23 additions & 7 deletions mypy/checker.py
Expand Up @@ -46,7 +46,8 @@
)
import mypy.checkexpr
from mypy.checkmember import (
analyze_member_access, analyze_descriptor_access, type_object_type,
MemberContext, analyze_member_access, analyze_descriptor_access, analyze_var,
type_object_type,
)
from mypy.typeops import (
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
Expand Down Expand Up @@ -3205,9 +3206,12 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
code=codes.ASSIGNMENT)
return rvalue_type, attribute_type, True

get_type = analyze_descriptor_access(
instance_type, attribute_type, self.named_type,
self.msg, context, chk=self)
mx = MemberContext(
is_lvalue=False, is_super=False, is_operator=False,
original_type=instance_type, context=context, self_type=None,
msg=self.msg, chk=self,
)
get_type = analyze_descriptor_access(attribute_type, mx)
if not attribute_type.type.has_readable_member('__set__'):
# If there is no __set__, we type-check that the assigned value matches
# the return type of __get__. This doesn't match the python semantics,
Expand All @@ -3221,9 +3225,15 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
if dunder_set is None:
self.fail(message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(attribute_type), context)
return AnyType(TypeOfAny.from_error), get_type, False

function = function_type(dunder_set, self.named_type('builtins.function'))
bound_method = bind_self(function, attribute_type)
if isinstance(dunder_set, Decorator):
bound_method = analyze_var(
'__set__', dunder_set.var, attribute_type, attribute_type.type, mx,
)
else:
bound_method = bind_self(
function_type(dunder_set, self.named_type('builtins.function')),
attribute_type,
)
typ = map_instance_to_supertype(attribute_type, dunder_set.info)
dunder_set_type = expand_type_by_instance(bound_method, typ)

Expand Down Expand Up @@ -6214,6 +6224,12 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool:
elif isinstance(typ, Instance):
method = typ.type.get_method('__call__')
if method:
if isinstance(method, Decorator):
return (
is_untyped_decorator(method.func.type)
or is_untyped_decorator(method.var.type)
)

if isinstance(method.type, Overloaded):
return any(is_untyped_decorator(item) for item in method.type.items)
else:
Expand Down
109 changes: 64 additions & 45 deletions mypy/checkmember.py
Expand Up @@ -67,14 +67,17 @@ def not_ready_callback(self, name: str, context: Context) -> None:
self.chk.handle_cannot_determine_type(name, context)

def copy_modified(self, *, messages: Optional[MessageBuilder] = None,
self_type: Optional[Type] = None) -> 'MemberContext':
self_type: Optional[Type] = None,
is_lvalue: Optional[bool] = None) -> 'MemberContext':
mx = MemberContext(self.is_lvalue, self.is_super, self.is_operator,
self.original_type, self.context, self.msg, self.chk,
self.self_type, self.module_symbol_table)
if messages is not None:
mx.msg = messages
if self_type is not None:
mx.self_type = self_type
if is_lvalue is not None:
mx.is_lvalue = is_lvalue
return mx


Expand Down Expand Up @@ -197,7 +200,7 @@ def analyze_instance_member_access(name: str,

# Look up the member. First look up the method dictionary.
method = info.get_method(name)
if method:
if method and not isinstance(method, Decorator):
if method.is_property:
assert isinstance(method, OverloadedFuncDef)
first_item = cast(Decorator, method.items[0])
Expand Down Expand Up @@ -390,29 +393,46 @@ def analyze_member_var_access(name: str,
if not mx.is_lvalue:
for method_name in ('__getattribute__', '__getattr__'):
method = info.get_method(method_name)

# __getattribute__ is defined on builtins.object and returns Any, so without
# the guard this search will always find object.__getattribute__ and conclude
# that the attribute exists
if method and method.info.fullname != 'builtins.object':
function = function_type(method, mx.named_type('builtins.function'))
bound_method = bind_self(function, mx.self_type)
if isinstance(method, Decorator):
# https://github.com/python/mypy/issues/10409
bound_method = analyze_var(method_name, method.var, itype, info, mx)
else:
bound_method = bind_self(
function_type(method, mx.named_type('builtins.function')),
mx.self_type,
)
typ = map_instance_to_supertype(itype, method.info)
getattr_type = get_proper_type(expand_type_by_instance(bound_method, typ))
if isinstance(getattr_type, CallableType):
result = getattr_type.ret_type

# Call the attribute hook before returning.
fullname = '{}.{}'.format(method.info.fullname, name)
hook = mx.chk.plugin.get_attribute_hook(fullname)
if hook:
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
return result
else:
result = getattr_type

# Call the attribute hook before returning.
fullname = '{}.{}'.format(method.info.fullname, name)
hook = mx.chk.plugin.get_attribute_hook(fullname)
if hook:
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
return result
else:
setattr_meth = info.get_method('__setattr__')
if setattr_meth and setattr_meth.info.fullname != 'builtins.object':
setattr_func = function_type(setattr_meth, mx.named_type('builtins.function'))
bound_type = bind_self(setattr_func, mx.self_type)
if isinstance(setattr_meth, Decorator):
bound_type = analyze_var(
name, setattr_meth.var, itype, info,
mx.copy_modified(is_lvalue=False),
)
else:
bound_type = bind_self(
function_type(setattr_meth, mx.named_type('builtins.function')),
mx.self_type,
)
typ = map_instance_to_supertype(itype, setattr_meth.info)
setattr_type = get_proper_type(expand_type_by_instance(bound_type, typ))
if isinstance(setattr_type, CallableType) and len(setattr_type.arg_types) > 0:
Expand Down Expand Up @@ -441,32 +461,24 @@ def check_final_member(name: str, info: TypeInfo, msg: MessageBuilder, ctx: Cont
msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx)


def analyze_descriptor_access(instance_type: Type,
descriptor_type: Type,
named_type: Callable[[str], Instance],
msg: MessageBuilder,
context: Context, *,
chk: 'mypy.checker.TypeChecker') -> Type:
def analyze_descriptor_access(descriptor_type: Type,
mx: MemberContext) -> Type:
"""Type check descriptor access.
Arguments:
instance_type: The type of the instance on which the descriptor
attribute is being accessed (the type of ``a`` in ``a.f`` when
``f`` is a descriptor).
descriptor_type: The type of the descriptor attribute being accessed
(the type of ``f`` in ``a.f`` when ``f`` is a descriptor).
context: The node defining the context of this inference.
mx: The current member access context.
Return:
The return type of the appropriate ``__get__`` overload for the descriptor.
"""
instance_type = get_proper_type(instance_type)
instance_type = get_proper_type(mx.original_type)
descriptor_type = get_proper_type(descriptor_type)

if isinstance(descriptor_type, UnionType):
# Map the access over union types
return make_simplified_union([
analyze_descriptor_access(instance_type, typ, named_type,
msg, context, chk=chk)
analyze_descriptor_access(typ, mx)
for typ in descriptor_type.items
])
elif not isinstance(descriptor_type, Instance):
Expand All @@ -476,13 +488,21 @@ def analyze_descriptor_access(instance_type: Type,
return descriptor_type

dunder_get = descriptor_type.type.get_method('__get__')

if dunder_get is None:
msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), context)
mx.msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type),
mx.context)
return AnyType(TypeOfAny.from_error)

function = function_type(dunder_get, named_type('builtins.function'))
bound_method = bind_self(function, descriptor_type)
if isinstance(dunder_get, Decorator):
bound_method = analyze_var(
'__set__', dunder_get.var, descriptor_type, descriptor_type.type, mx,
)
else:
bound_method = bind_self(
function_type(dunder_get, mx.named_type('builtins.function')),
descriptor_type,
)

typ = map_instance_to_supertype(descriptor_type, dunder_get.info)
dunder_get_type = expand_type_by_instance(bound_method, typ)

Expand All @@ -495,19 +515,19 @@ def analyze_descriptor_access(instance_type: Type,
else:
owner_type = instance_type

callable_name = chk.expr_checker.method_fullname(descriptor_type, "__get__")
dunder_get_type = chk.expr_checker.transform_callee_type(
callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__get__")
dunder_get_type = mx.chk.expr_checker.transform_callee_type(
callable_name, dunder_get_type,
[TempNode(instance_type, context=context),
TempNode(TypeType.make_normalized(owner_type), context=context)],
[ARG_POS, ARG_POS], context, object_type=descriptor_type,
[TempNode(instance_type, context=mx.context),
TempNode(TypeType.make_normalized(owner_type), context=mx.context)],
[ARG_POS, ARG_POS], mx.context, object_type=descriptor_type,
)

_, inferred_dunder_get_type = chk.expr_checker.check_call(
_, inferred_dunder_get_type = mx.chk.expr_checker.check_call(
dunder_get_type,
[TempNode(instance_type, context=context),
TempNode(TypeType.make_normalized(owner_type), context=context)],
[ARG_POS, ARG_POS], context, object_type=descriptor_type,
[TempNode(instance_type, context=mx.context),
TempNode(TypeType.make_normalized(owner_type), context=mx.context)],
[ARG_POS, ARG_POS], mx.context, object_type=descriptor_type,
callable_name=callable_name)

inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type)
Expand All @@ -516,7 +536,8 @@ def analyze_descriptor_access(instance_type: Type,
return inferred_dunder_get_type

if not isinstance(inferred_dunder_get_type, CallableType):
msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), context)
mx.msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type),
mx.context)
return AnyType(TypeOfAny.from_error)

return inferred_dunder_get_type.ret_type
Expand Down Expand Up @@ -605,8 +626,7 @@ def analyze_var(name: str,
fullname = '{}.{}'.format(var.info.fullname, name)
hook = mx.chk.plugin.get_attribute_hook(fullname)
if result and not mx.is_lvalue and not implicit:
result = analyze_descriptor_access(mx.original_type, result, mx.named_type,
mx.msg, mx.context, chk=mx.chk)
result = analyze_descriptor_access(result, mx)
if hook:
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
Expand Down Expand Up @@ -785,8 +805,7 @@ def analyze_class_attribute_access(itype: Instance,
result = add_class_tvars(t, isuper, is_classmethod,
mx.self_type, original_vars=original_vars)
if not mx.is_lvalue:
result = analyze_descriptor_access(mx.original_type, result, mx.named_type,
mx.msg, mx.context, chk=mx.chk)
result = analyze_descriptor_access(result, mx)
return result
elif isinstance(node.node, Var):
mx.not_ready_callback(name, mx.context)
Expand Down
4 changes: 3 additions & 1 deletion mypy/nodes.py
Expand Up @@ -2687,12 +2687,14 @@ def __bool__(self) -> bool:
def has_readable_member(self, name: str) -> bool:
return self.get(name) is not None

def get_method(self, name: str) -> Optional[FuncBase]:
def get_method(self, name: str) -> Union[FuncBase, Decorator, None]:
for cls in self.mro:
if name in cls.names:
node = cls.names[name].node
if isinstance(node, FuncBase):
return node
elif isinstance(node, Decorator): # Two `if`s make `mypyc` happy
return node
else:
return None
return None
Expand Down
24 changes: 12 additions & 12 deletions mypy/subtypes.py
Expand Up @@ -650,6 +650,8 @@ def find_member(name: str,
info = itype.type
method = info.get_method(name)
if method:
if isinstance(method, Decorator):
return find_node_type(method.var, itype, subtype)
if method.is_property:
assert isinstance(method, OverloadedFuncDef)
dec = method.items[0]
Expand All @@ -659,12 +661,7 @@ def find_member(name: str,
else:
# don't have such method, maybe variable or decorator?
node = info.get(name)
if not node:
v = None
else:
v = node.node
if isinstance(v, Decorator):
v = v.var
v = node.node if node else None
if isinstance(v, Var):
return find_node_type(v, itype, subtype)
if (not v and name not in ['__getattr__', '__setattr__', '__getattribute__'] and
Expand All @@ -676,9 +673,13 @@ def find_member(name: str,
# structural subtyping.
method = info.get_method(method_name)
if method and method.info.fullname != 'builtins.object':
getattr_type = get_proper_type(find_node_type(method, itype, subtype))
if isinstance(method, Decorator):
getattr_type = get_proper_type(find_node_type(method.var, itype, subtype))
else:
getattr_type = get_proper_type(find_node_type(method, itype, subtype))
if isinstance(getattr_type, CallableType):
return getattr_type.ret_type
return getattr_type
if itype.type.fallback_to_any:
return AnyType(TypeOfAny.special_form)
return None
Expand All @@ -698,8 +699,10 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
method = info.get_method(name)
setattr_meth = info.get_method('__setattr__')
if method:
# this could be settable property
if method.is_property:
if isinstance(method, Decorator):
if method.var.is_staticmethod or method.var.is_classmethod:
return {IS_CLASS_OR_STATIC}
elif method.is_property: # this could be settable property
assert isinstance(method, OverloadedFuncDef)
dec = method.items[0]
assert isinstance(dec, Decorator)
Expand All @@ -712,9 +715,6 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
return {IS_SETTABLE}
return set()
v = node.node
if isinstance(v, Decorator):
if v.var.is_staticmethod or v.var.is_classmethod:
return {IS_CLASS_OR_STATIC}
# just a variable
if isinstance(v, Var) and not v.is_property:
flags = {IS_SETTABLE}
Expand Down

0 comments on commit 4e34fec

Please sign in to comment.