diff --git a/pyupgrade.py b/pyupgrade.py index 6b84e939..9b16c591 100644 --- a/pyupgrade.py +++ b/pyupgrade.py @@ -1207,7 +1207,7 @@ class ClassInfo: def __init__(self, node: ast.ClassDef) -> None: self.bases = node.bases self.name = node.name - self.def_depth = 0 + self.def_names: List[str] = [] self.first_arg_name = '' class Scope: @@ -1397,13 +1397,14 @@ def _track_def_depth( node: AnyFunctionDef, ) -> Generator[None, None, None]: class_info = self._class_info_stack[-1] - class_info.def_depth += 1 - if class_info.def_depth == 1 and node.args.args: + def_name = '' if isinstance(node, ast.Lambda) else node.name + class_info.def_names.append(def_name) + if len(class_info.def_names) == 1 and node.args.args: class_info.first_arg_name = node.args.args[0].arg try: yield finally: - class_info.def_depth -= 1 + class_info.def_names.pop() @contextlib.contextmanager def _scope(self) -> Generator[None, None, None]: @@ -1543,7 +1544,7 @@ def visit_Call(self, node: ast.Call) -> None: elif ( not self._in_comp and self._class_info_stack and - self._class_info_stack[-1].def_depth == 1 and + len(self._class_info_stack[-1].def_names) == 1 and isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 2 and @@ -1556,13 +1557,14 @@ def visit_Call(self, node: ast.Call) -> None: elif ( not self._in_comp and self._class_info_stack and - self._class_info_stack[-1].def_depth == 1 and + len(self._class_info_stack[-1].def_names) == 1 and len(self._class_info_stack[-1].bases) == 1 and _is_simple_base(self._class_info_stack[-1].bases[0]) and isinstance(node.func, ast.Attribute) and targets_same( self._class_info_stack[-1].bases[0], node.func.value, ) and + node.func.attr == self._class_info_stack[-1].def_names[0] and len(node.args) >= 1 and isinstance(node.args[0], ast.Name) and node.args[0].id == self._class_info_stack[-1].first_arg_name diff --git a/tests/super_test.py b/tests/super_test.py index 95eba81e..377dfa43 100644 --- a/tests/super_test.py +++ b/tests/super_test.py @@ -172,6 +172,15 @@ def test_fix_super(s, expected): ' a().b.f(self)\n', id='non simple attribute base', ), + pytest.param( + 'class C:\n' + ' @classmethod\n' + ' def make(cls, instance):\n' + ' ...\n' + 'class D(C):\n' + ' def find(self):\n' + ' return C.make(self)\n', + ), ), ) def test_old_style_class_super_noop(s):