From 17cb81f01df60f234fba17f93c7ed79d01a8e678 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 30 Jan 2021 11:46:05 -0800 Subject: [PATCH] Revert "Revert "Merge pull request #317 from asottile/old_super"" This reverts commit 2719335fa7bdb582b35ac90547a0f763d4225036. --- pyupgrade/_ast_helpers.py | 44 +++++++++++++++ pyupgrade/_plugins/legacy.py | 90 ++++++++++++++++--------------- tests/ast_helpers_test.py | 19 +++++++ tests/features/super_test.py | 74 +++++++++++++++++++++++++ tests/features/yield_from_test.py | 19 ------- 5 files changed, 185 insertions(+), 61 deletions(-) create mode 100644 tests/ast_helpers_test.py diff --git a/pyupgrade/_ast_helpers.py b/pyupgrade/_ast_helpers.py index eb4ba066..2f5c05ae 100644 --- a/pyupgrade/_ast_helpers.py +++ b/pyupgrade/_ast_helpers.py @@ -1,8 +1,12 @@ import ast import warnings +from typing import Any from typing import Container from typing import Dict +from typing import Iterable from typing import Set +from typing import Tuple +from typing import Type from typing import Union from tokenize_rt import Offset @@ -57,3 +61,43 @@ def is_async_listcomp(node: ast.ListComp) -> bool: any(gen.is_async for gen in node.generators) or contains_await(node) ) + + +def _all_isinstance( + vals: Iterable[Any], + tp: Union[Type[Any], Tuple[Type[Any], ...]], +) -> bool: + return all(isinstance(v, tp) for v in vals) + + +def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: + for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): + # ignore ast attributes, they'll be covered by walk + if a1 != a2: + return False + elif _all_isinstance((v1, v2), ast.AST): + continue + elif _all_isinstance((v1, v2), (list, tuple)): + if len(v1) != len(v2): + return False + # ignore sequences which are all-ast, they'll be covered by walk + elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): + continue + elif v1 != v2: + return False + elif v1 != v2: + return False + return True + + +def targets_same(node1: ast.AST, node2: ast.AST) -> bool: + for t1, t2 in zip(ast.walk(node1), ast.walk(node2)): + # ignore `ast.Load` / `ast.Store` + if _all_isinstance((t1, t2), ast.expr_context): + continue + elif type(t1) != type(t2): + return False + elif not _fields_same(t1, t2): + return False + else: + return True diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index 321ffba0..e8443065 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -2,28 +2,28 @@ import collections import contextlib import functools -from typing import Any from typing import Dict from typing import Generator from typing import Iterable from typing import List from typing import Set from typing import Tuple -from typing import Type -from typing import Union from tokenize_rt import Offset from tokenize_rt import Token from tokenize_rt import tokens_to_src from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import targets_same from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc from pyupgrade._token_helpers import Block from pyupgrade._token_helpers import find_and_replace_call from pyupgrade._token_helpers import find_block_start +from pyupgrade._token_helpers import find_open_paren from pyupgrade._token_helpers import find_token +from pyupgrade._token_helpers import parse_call_args FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef) @@ -36,44 +36,27 @@ def _fix_yield(i: int, tokens: List[Token]) -> None: tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')] -def _all_isinstance( - vals: Iterable[Any], - tp: Union[Type[Any], Tuple[Type[Any], ...]], -) -> bool: - return all(isinstance(v, tp) for v in vals) - - -def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: - for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): - # ignore ast attributes, they'll be covered by walk - if a1 != a2: - return False - elif _all_isinstance((v1, v2), ast.AST): - continue - elif _all_isinstance((v1, v2), (list, tuple)): - if len(v1) != len(v2): - return False - # ignore sequences which are all-ast, they'll be covered by walk - elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): - continue - elif v1 != v2: - return False - elif v1 != v2: - return False - return True - - -def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool: - for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)): - # ignore `ast.Load` / `ast.Store` - if _all_isinstance((t1, t2), ast.expr_context): - continue - elif type(t1) != type(t2): - return False - elif not _fields_same(t1, t2): - return False +def _fix_old_super(i: int, tokens: List[Token]) -> None: + j = find_open_paren(tokens, i) + k = j - 1 + while tokens[k].src != '.': + k -= 1 + func_args, end = parse_call_args(tokens, j) + # remove the first argument + if len(func_args) == 1: + del tokens[func_args[0][0]:func_args[0][0] + 1] else: - return True + del tokens[func_args[0][0]:func_args[1][0] + 1] + tokens[i:k] = [Token('CODE', 'super()')] + + +def _is_simple_base(base: ast.AST) -> bool: + return ( + isinstance(base, ast.Name) or ( + isinstance(base, ast.Attribute) and + _is_simple_base(base.value) + ) + ) class Scope: @@ -92,6 +75,7 @@ class Visitor(ast.NodeVisitor): def __init__(self) -> None: self._scopes: List[Scope] = [] self.super_offsets: Set[Offset] = set() + self.old_super_offsets: Set[Offset] = set() self.yield_offsets: Set[Offset] = set() @contextlib.contextmanager @@ -137,7 +121,6 @@ def visit_Call(self, node: ast.Call) -> None: len(node.args) == 2 and isinstance(node.args[0], ast.Name) and isinstance(node.args[1], ast.Name) and - # there are at least two scopes len(self._scopes) >= 2 and # the second to last scope is the class in arg1 isinstance(self._scopes[-2].node, ast.ClassDef) and @@ -148,6 +131,26 @@ def visit_Call(self, node: ast.Call) -> None: node.args[1].id == self._scopes[-1].node.args.args[0].arg ): self.super_offsets.add(ast_to_offset(node)) + elif ( + len(self._scopes) >= 2 and + # last stack is a function whose first argument is the first + # argument of this function + len(node.args) >= 1 and + isinstance(node.args[0], ast.Name) and + isinstance(self._scopes[-1].node, FUNC_TYPES) and + len(self._scopes[-1].node.args.args) >= 1 and + node.args[0].id == self._scopes[-1].node.args.args[0].arg and + # the function is an attribute of the contained class name + isinstance(node.func, ast.Attribute) and + isinstance(self._scopes[-2].node, ast.ClassDef) and + len(self._scopes[-2].node.bases) == 1 and + _is_simple_base(self._scopes[-2].node.bases[0]) and + targets_same( + self._scopes[-2].node.bases[0], + node.func.value, + ) + ): + self.old_super_offsets.add(ast_to_offset(node)) self.generic_visit(node) @@ -159,7 +162,7 @@ def visit_For(self, node: ast.For) -> None: isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Yield) and node.body[0].value.value is not None and - _targets_same(node.target, node.body[0].value.value) and + targets_same(node.target, node.body[0].value.value) and not node.orelse ): offset = ast_to_offset(node) @@ -198,5 +201,8 @@ def visit_Module( for offset in visitor.super_offsets: yield offset, super_func + for offset in visitor.old_super_offsets: + yield offset, _fix_old_super + for offset in visitor.yield_offsets: yield offset, _fix_yield diff --git a/tests/ast_helpers_test.py b/tests/ast_helpers_test.py new file mode 100644 index 00000000..282ea729 --- /dev/null +++ b/tests/ast_helpers_test.py @@ -0,0 +1,19 @@ +import ast + +from pyupgrade._ast_helpers import _fields_same +from pyupgrade._ast_helpers import targets_same + + +def test_targets_same(): + assert targets_same(ast.parse('global a, b'), ast.parse('global a, b')) + assert not targets_same(ast.parse('global a'), ast.parse('global b')) + + +def _get_body(expr): + body = ast.parse(expr).body[0] + assert isinstance(body, ast.Expr) + return body.value + + +def test_fields_same(): + assert not _fields_same(_get_body('x'), _get_body('1')) diff --git a/tests/features/super_test.py b/tests/features/super_test.py index ee558a6d..6e2103c9 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -122,3 +122,77 @@ def test_fix_super_noop(s): ) def test_fix_super(s, expected): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(notself)\n', + id='old style super, first argument is not first function arg', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B1.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance first class', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B2.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance not-first class', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' return [Base.f(self) for _ in ()]\n', + id='super in comprehension', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' def g():\n' + ' Base.f(self)\n' + ' g()\n', + id='super in nested functions', + ), + pytest.param( + 'class C(not_simple()):\n' + ' def f(self):\n' + ' not_simple().f(self)\n', + id='not a simple base', + ), + pytest.param( + 'class C(a().b):\n' + ' def f(self):\n' + ' a().b.f(self)\n', + id='non simple attribute base', + ), + ), +) +def test_old_style_class_super_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + ( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(self)\n' + ' B.f(self, arg, arg)\n', + 'class C(B):\n' + ' def f(self):\n' + ' super().f()\n' + ' super().f(arg, arg)\n', + ), + ), +) +def test_old_style_class_super(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected diff --git a/tests/features/yield_from_test.py b/tests/features/yield_from_test.py index d3d12326..dc040993 100644 --- a/tests/features/yield_from_test.py +++ b/tests/features/yield_from_test.py @@ -1,11 +1,7 @@ -import ast - import pytest from pyupgrade._data import Settings from pyupgrade._main import _fix_plugins -from pyupgrade._plugins.legacy import _fields_same -from pyupgrade._plugins.legacy import _targets_same @pytest.mark.parametrize( @@ -215,18 +211,3 @@ def test_fix_yield_from(s, expected): ) def test_fix_yield_from_noop(s): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s - - -def test_targets_same(): - assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b')) - assert not _targets_same(ast.parse('global a'), ast.parse('global b')) - - -def _get_body(expr): - body = ast.parse(expr).body[0] - assert isinstance(body, ast.Expr) - return body.value - - -def test_fields_same(): - assert not _fields_same(_get_body('x'), _get_body('1'))