From 7500673b58106fb35a9545f9b1ea9d4bded60a4b Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Fri, 19 Jun 2020 14:27:36 -0700 Subject: [PATCH] Revert "Revert "Merge pull request #317 from asottile/old_super"" This reverts commit 2719335fa7bdb582b35ac90547a0f763d4225036. --- pyupgrade.py | 49 +++++++++++++++++++++++++++--- tests/super_test.py | 74 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 5 deletions(-) diff --git a/pyupgrade.py b/pyupgrade.py index 54bbe0d7..2233ad35 100644 --- a/pyupgrade.py +++ b/pyupgrade.py @@ -1163,8 +1163,8 @@ def fields_same(n1: ast.AST, n2: ast.AST) -> bool: return True -def targets_same(target: ast.AST, yield_value: ast.AST) -> bool: - for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)): +def targets_same(node1: ast.AST, node2: ast.AST) -> bool: + for t1, t2 in zip(ast.walk(node1), ast.walk(node2)): # ignore `ast.Load` / `ast.Store` if _all_isinstance((t1, t2), ast.expr_context): continue @@ -1183,6 +1183,15 @@ def _is_codec(encoding: str, name: str) -> bool: return False +def _is_simple_base(base: ast.AST) -> bool: + return ( + isinstance(base, ast.Name) or ( + isinstance(base, ast.Attribute) and + _is_simple_base(base.value) + ) + ) + + class FindPy3Plus(ast.NodeVisitor): OS_ERROR_ALIASES = frozenset(( 'EnvironmentError', @@ -1201,8 +1210,9 @@ class FindPy3Plus(ast.NodeVisitor): MOCK_MODULES = frozenset(('mock', 'mock.mock')) class ClassInfo: - def __init__(self, name: str) -> None: - self.name = name + def __init__(self, node: ast.ClassDef) -> None: + self.bases = node.bases + self.name = node.name self.def_depth = 0 self.first_arg_name = '' @@ -1256,6 +1266,7 @@ def __init__(self, keep_mock: bool) -> None: self._class_info_stack: List[FindPy3Plus.ClassInfo] = [] self._in_comp = 0 self.super_calls: Dict[Offset, ast.Call] = {} + self.old_style_super_calls: Set[Offset] = set() self._in_async_def = False self._scope_stack: List[FindPy3Plus.Scope] = [] self.yield_from_fors: Set[Offset] = set() @@ -1408,7 +1419,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: ): self.six_with_metaclass.add(_ast_to_offset(node.bases[0])) - self._class_info_stack.append(FindPy3Plus.ClassInfo(node.name)) + self._class_info_stack.append(FindPy3Plus.ClassInfo(node)) self.generic_visit(node) self._class_info_stack.pop() @@ -1574,6 +1585,21 @@ def visit_Call(self, node: ast.Call) -> None: node.args[1].id == self._class_info_stack[-1].first_arg_name ): self.super_calls[_ast_to_offset(node)] = node + elif ( + not self._in_comp and + self._class_info_stack and + self._class_info_stack[-1].def_depth == 1 and + len(self._class_info_stack[-1].bases) == 1 and + _is_simple_base(self._class_info_stack[-1].bases[0]) and + isinstance(node.func, ast.Attribute) and + targets_same( + self._class_info_stack[-1].bases[0], node.func.value, + ) and + len(node.args) >= 1 and + isinstance(node.args[0], ast.Name) and + node.args[0].id == self._class_info_stack[-1].first_arg_name + ): + self.old_style_super_calls.add(_ast_to_offset(node)) elif ( ( self._is_six(node.func, SIX_NATIVE_STR) or @@ -2070,6 +2096,7 @@ def _fix_py3_plus( visitor.six_type_ctx, visitor.six_with_metaclass, visitor.super_calls, + visitor.old_style_super_calls, visitor.yield_from_fors, )): return contents_text @@ -2232,6 +2259,18 @@ def _replace(i: int, mapping: Dict[str, str], node: NameOrAttr) -> None: call = visitor.super_calls[token.offset] victims = _victims(tokens, i, call, gen=False) del tokens[victims.starts[0] + 1:victims.ends[-1]] + elif token.offset in visitor.old_style_super_calls: + j = _find_open_paren(tokens, i) + k = j - 1 + while tokens[k].src != '.': + k -= 1 + func_args, end = _parse_call_args(tokens, j) + # remove the first argument + if len(func_args) == 1: + del tokens[func_args[0][0]:func_args[0][0] + 1] + else: + del tokens[func_args[0][0]:func_args[1][0] + 1] + tokens[i:k] = [Token('CODE', 'super()')] elif token.offset in visitor.encode_calls: i = _find_open_paren(tokens, i) call = visitor.encode_calls[token.offset] diff --git a/tests/super_test.py b/tests/super_test.py index 7ec00956..95eba81e 100644 --- a/tests/super_test.py +++ b/tests/super_test.py @@ -121,3 +121,77 @@ def test_fix_super_noop(s): ) def test_fix_super(s, expected): assert _fix_py3_plus(s, (3,)) == expected + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(notself)\n', + id='old style super, first argument is not first function arg', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B1.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance first class', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B2.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance not-first class', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' return [Base.f(self) for _ in ()]\n', + id='super in comprehension', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' def g():\n' + ' Base.f(self)\n' + ' g()\n', + id='super in nested functions', + ), + pytest.param( + 'class C(not_simple()):\n' + ' def f(self):\n' + ' not_simple().f(self)\n', + id='not a simple base', + ), + pytest.param( + 'class C(a().b):\n' + ' def f(self):\n' + ' a().b.f(self)\n', + id='non simple attribute base', + ), + ), +) +def test_old_style_class_super_noop(s): + assert _fix_py3_plus(s, (3,)) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + ( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(self)\n' + ' B.f(self, arg, arg)\n', + 'class C(B):\n' + ' def f(self):\n' + ' super().f()\n' + ' super().f(arg, arg)\n', + ), + ), +) +def test_old_style_class_super(s, expected): + assert _fix_py3_plus(s, (3,)) == expected