Skip to content

Commit

Permalink
Make disable_type_names a context manager (#11716)
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn committed Jan 7, 2022
1 parent e8cf960 commit 55bee20
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
28 changes: 18 additions & 10 deletions mypy/checkexpr.py
Expand Up @@ -2064,11 +2064,19 @@ def check_union_call(self,
arg_names: Optional[Sequence[Optional[str]]],
context: Context,
arg_messages: MessageBuilder) -> Tuple[Type, Type]:
self.msg.disable_type_names += 1
results = [self.check_call(subtype, args, arg_kinds, context, arg_names,
arg_messages=arg_messages)
for subtype in callee.relevant_items()]
self.msg.disable_type_names -= 1
with self.msg.disable_type_names():
results = [
self.check_call(
subtype,
args,
arg_kinds,
context,
arg_names,
arg_messages=arg_messages,
)
for subtype in callee.relevant_items()
]

return (make_simplified_union([res[0] for res in results]),
callee)

Expand Down Expand Up @@ -2462,11 +2470,11 @@ def check_union_method_call_by_name(self,
for typ in base_type.relevant_items():
# Format error messages consistently with
# mypy.checkmember.analyze_union_member_access().
local_errors.disable_type_names += 1
item, meth_item = self.check_method_call_by_name(method, typ, args, arg_kinds,
context, local_errors,
original_type)
local_errors.disable_type_names -= 1
with local_errors.disable_type_names():
item, meth_item = self.check_method_call_by_name(
method, typ, args, arg_kinds,
context, local_errors, original_type,
)
res.append(item)
meth_res.append(meth_item)
return make_simplified_union(res), make_simplified_union(meth_res)
Expand Down
13 changes: 6 additions & 7 deletions mypy/checkmember.py
Expand Up @@ -311,13 +311,12 @@ def analyze_type_type_member_access(name: str,


def analyze_union_member_access(name: str, typ: UnionType, mx: MemberContext) -> Type:
mx.msg.disable_type_names += 1
results = []
for subtype in typ.relevant_items():
# Self types should be bound to every individual item of a union.
item_mx = mx.copy_modified(self_type=subtype)
results.append(_analyze_member_access(name, subtype, item_mx))
mx.msg.disable_type_names -= 1
with mx.msg.disable_type_names():
results = []
for subtype in typ.relevant_items():
# Self types should be bound to every individual item of a union.
item_mx = mx.copy_modified(self_type=subtype)
results.append(_analyze_member_access(name, subtype, item_mx))
return make_simplified_union(results)


Expand Down
20 changes: 14 additions & 6 deletions mypy/messages.py
Expand Up @@ -107,13 +107,13 @@ class MessageBuilder:
disable_count = 0

# Hack to deduplicate error messages from union types
disable_type_names = 0
disable_type_names_count = 0

def __init__(self, errors: Errors, modules: Dict[str, MypyFile]) -> None:
self.errors = errors
self.modules = modules
self.disable_count = 0
self.disable_type_names = 0
self.disable_type_names_count = 0

#
# Helpers
Expand All @@ -122,7 +122,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile]) -> None:
def copy(self) -> 'MessageBuilder':
new = MessageBuilder(self.errors.copy(), self.modules)
new.disable_count = self.disable_count
new.disable_type_names = self.disable_type_names
new.disable_type_names_count = self.disable_type_names_count
return new

def clean_copy(self) -> 'MessageBuilder':
Expand All @@ -145,6 +145,14 @@ def disable_errors(self) -> Iterator[None]:
finally:
self.disable_count -= 1

@contextmanager
def disable_type_names(self) -> Iterator[None]:
self.disable_type_names_count += 1
try:
yield
finally:
self.disable_type_names_count -= 1

def is_errors(self) -> bool:
return self.errors.is_errors()

Expand Down Expand Up @@ -298,7 +306,7 @@ def has_no_attr(self,
extra = ' (not iterable)'
elif member == '__aiter__':
extra = ' (not async iterable)'
if not self.disable_type_names:
if not self.disable_type_names_count:
failed = False
if isinstance(original_type, Instance) and original_type.type.names:
alternatives = set(original_type.type.names.keys())
Expand Down Expand Up @@ -380,7 +388,7 @@ def unsupported_operand_types(self,
else:
right_str = format_type(right_type)

if self.disable_type_names:
if self.disable_type_names_count:
msg = 'Unsupported operand types for {} (likely involving Union)'.format(op)
else:
msg = 'Unsupported operand types for {} ({} and {})'.format(
Expand All @@ -389,7 +397,7 @@ def unsupported_operand_types(self,

def unsupported_left_operand(self, op: str, typ: Type,
context: Context) -> None:
if self.disable_type_names:
if self.disable_type_names_count:
msg = 'Unsupported left operand type for {} (some union)'.format(op)
else:
msg = 'Unsupported left operand type for {} ({})'.format(
Expand Down

0 comments on commit 55bee20

Please sign in to comment.