From 820f9ae7e2d813fe6c61918d0bcb113b7cd713d5 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 30 Jan 2021 11:46:05 -0800 Subject: [PATCH 1/4] Revert "Revert "Merge pull request #317 from asottile/old_super"" This reverts commit 2719335fa7bdb582b35ac90547a0f763d4225036. --- pyupgrade/_ast_helpers.py | 44 +++++++++++++++ pyupgrade/_plugins/legacy.py | 90 ++++++++++++++++--------------- tests/ast_helpers_test.py | 19 +++++++ tests/features/super_test.py | 74 +++++++++++++++++++++++++ tests/features/yield_from_test.py | 19 ------- 5 files changed, 185 insertions(+), 61 deletions(-) create mode 100644 tests/ast_helpers_test.py diff --git a/pyupgrade/_ast_helpers.py b/pyupgrade/_ast_helpers.py index eb4ba066..2f5c05ae 100644 --- a/pyupgrade/_ast_helpers.py +++ b/pyupgrade/_ast_helpers.py @@ -1,8 +1,12 @@ import ast import warnings +from typing import Any from typing import Container from typing import Dict +from typing import Iterable from typing import Set +from typing import Tuple +from typing import Type from typing import Union from tokenize_rt import Offset @@ -57,3 +61,43 @@ def is_async_listcomp(node: ast.ListComp) -> bool: any(gen.is_async for gen in node.generators) or contains_await(node) ) + + +def _all_isinstance( + vals: Iterable[Any], + tp: Union[Type[Any], Tuple[Type[Any], ...]], +) -> bool: + return all(isinstance(v, tp) for v in vals) + + +def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: + for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): + # ignore ast attributes, they'll be covered by walk + if a1 != a2: + return False + elif _all_isinstance((v1, v2), ast.AST): + continue + elif _all_isinstance((v1, v2), (list, tuple)): + if len(v1) != len(v2): + return False + # ignore sequences which are all-ast, they'll be covered by walk + elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): + continue + elif v1 != v2: + return False + elif v1 != v2: + return False + return True + + +def targets_same(node1: ast.AST, node2: ast.AST) -> bool: + for t1, t2 in zip(ast.walk(node1), ast.walk(node2)): + # ignore `ast.Load` / `ast.Store` + if _all_isinstance((t1, t2), ast.expr_context): + continue + elif type(t1) != type(t2): + return False + elif not _fields_same(t1, t2): + return False + else: + return True diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index 321ffba0..e8443065 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -2,28 +2,28 @@ import collections import contextlib import functools -from typing import Any from typing import Dict from typing import Generator from typing import Iterable from typing import List from typing import Set from typing import Tuple -from typing import Type -from typing import Union from tokenize_rt import Offset from tokenize_rt import Token from tokenize_rt import tokens_to_src from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import targets_same from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc from pyupgrade._token_helpers import Block from pyupgrade._token_helpers import find_and_replace_call from pyupgrade._token_helpers import find_block_start +from pyupgrade._token_helpers import find_open_paren from pyupgrade._token_helpers import find_token +from pyupgrade._token_helpers import parse_call_args FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef) @@ -36,44 +36,27 @@ def _fix_yield(i: int, tokens: List[Token]) -> None: tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')] -def _all_isinstance( - vals: Iterable[Any], - tp: Union[Type[Any], Tuple[Type[Any], ...]], -) -> bool: - return all(isinstance(v, tp) for v in vals) - - -def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: - for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): - # ignore ast attributes, they'll be covered by walk - if a1 != a2: - return False - elif _all_isinstance((v1, v2), ast.AST): - continue - elif _all_isinstance((v1, v2), (list, tuple)): - if len(v1) != len(v2): - return False - # ignore sequences which are all-ast, they'll be covered by walk - elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): - continue - elif v1 != v2: - return False - elif v1 != v2: - return False - return True - - -def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool: - for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)): - # ignore `ast.Load` / `ast.Store` - if _all_isinstance((t1, t2), ast.expr_context): - continue - elif type(t1) != type(t2): - return False - elif not _fields_same(t1, t2): - return False +def _fix_old_super(i: int, tokens: List[Token]) -> None: + j = find_open_paren(tokens, i) + k = j - 1 + while tokens[k].src != '.': + k -= 1 + func_args, end = parse_call_args(tokens, j) + # remove the first argument + if len(func_args) == 1: + del tokens[func_args[0][0]:func_args[0][0] + 1] else: - return True + del tokens[func_args[0][0]:func_args[1][0] + 1] + tokens[i:k] = [Token('CODE', 'super()')] + + +def _is_simple_base(base: ast.AST) -> bool: + return ( + isinstance(base, ast.Name) or ( + isinstance(base, ast.Attribute) and + _is_simple_base(base.value) + ) + ) class Scope: @@ -92,6 +75,7 @@ class Visitor(ast.NodeVisitor): def __init__(self) -> None: self._scopes: List[Scope] = [] self.super_offsets: Set[Offset] = set() + self.old_super_offsets: Set[Offset] = set() self.yield_offsets: Set[Offset] = set() @contextlib.contextmanager @@ -137,7 +121,6 @@ def visit_Call(self, node: ast.Call) -> None: len(node.args) == 2 and isinstance(node.args[0], ast.Name) and isinstance(node.args[1], ast.Name) and - # there are at least two scopes len(self._scopes) >= 2 and # the second to last scope is the class in arg1 isinstance(self._scopes[-2].node, ast.ClassDef) and @@ -148,6 +131,26 @@ def visit_Call(self, node: ast.Call) -> None: node.args[1].id == self._scopes[-1].node.args.args[0].arg ): self.super_offsets.add(ast_to_offset(node)) + elif ( + len(self._scopes) >= 2 and + # last stack is a function whose first argument is the first + # argument of this function + len(node.args) >= 1 and + isinstance(node.args[0], ast.Name) and + isinstance(self._scopes[-1].node, FUNC_TYPES) and + len(self._scopes[-1].node.args.args) >= 1 and + node.args[0].id == self._scopes[-1].node.args.args[0].arg and + # the function is an attribute of the contained class name + isinstance(node.func, ast.Attribute) and + isinstance(self._scopes[-2].node, ast.ClassDef) and + len(self._scopes[-2].node.bases) == 1 and + _is_simple_base(self._scopes[-2].node.bases[0]) and + targets_same( + self._scopes[-2].node.bases[0], + node.func.value, + ) + ): + self.old_super_offsets.add(ast_to_offset(node)) self.generic_visit(node) @@ -159,7 +162,7 @@ def visit_For(self, node: ast.For) -> None: isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Yield) and node.body[0].value.value is not None and - _targets_same(node.target, node.body[0].value.value) and + targets_same(node.target, node.body[0].value.value) and not node.orelse ): offset = ast_to_offset(node) @@ -198,5 +201,8 @@ def visit_Module( for offset in visitor.super_offsets: yield offset, super_func + for offset in visitor.old_super_offsets: + yield offset, _fix_old_super + for offset in visitor.yield_offsets: yield offset, _fix_yield diff --git a/tests/ast_helpers_test.py b/tests/ast_helpers_test.py new file mode 100644 index 00000000..282ea729 --- /dev/null +++ b/tests/ast_helpers_test.py @@ -0,0 +1,19 @@ +import ast + +from pyupgrade._ast_helpers import _fields_same +from pyupgrade._ast_helpers import targets_same + + +def test_targets_same(): + assert targets_same(ast.parse('global a, b'), ast.parse('global a, b')) + assert not targets_same(ast.parse('global a'), ast.parse('global b')) + + +def _get_body(expr): + body = ast.parse(expr).body[0] + assert isinstance(body, ast.Expr) + return body.value + + +def test_fields_same(): + assert not _fields_same(_get_body('x'), _get_body('1')) diff --git a/tests/features/super_test.py b/tests/features/super_test.py index ee558a6d..6e2103c9 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -122,3 +122,77 @@ def test_fix_super_noop(s): ) def test_fix_super(s, expected): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(notself)\n', + id='old style super, first argument is not first function arg', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B1.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance first class', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B2.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance not-first class', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' return [Base.f(self) for _ in ()]\n', + id='super in comprehension', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' def g():\n' + ' Base.f(self)\n' + ' g()\n', + id='super in nested functions', + ), + pytest.param( + 'class C(not_simple()):\n' + ' def f(self):\n' + ' not_simple().f(self)\n', + id='not a simple base', + ), + pytest.param( + 'class C(a().b):\n' + ' def f(self):\n' + ' a().b.f(self)\n', + id='non simple attribute base', + ), + ), +) +def test_old_style_class_super_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + ( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(self)\n' + ' B.f(self, arg, arg)\n', + 'class C(B):\n' + ' def f(self):\n' + ' super().f()\n' + ' super().f(arg, arg)\n', + ), + ), +) +def test_old_style_class_super(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected diff --git a/tests/features/yield_from_test.py b/tests/features/yield_from_test.py index d3d12326..dc040993 100644 --- a/tests/features/yield_from_test.py +++ b/tests/features/yield_from_test.py @@ -1,11 +1,7 @@ -import ast - import pytest from pyupgrade._data import Settings from pyupgrade._main import _fix_plugins -from pyupgrade._plugins.legacy import _fields_same -from pyupgrade._plugins.legacy import _targets_same @pytest.mark.parametrize( @@ -215,18 +211,3 @@ def test_fix_yield_from(s, expected): ) def test_fix_yield_from_noop(s): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s - - -def test_targets_same(): - assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b')) - assert not _targets_same(ast.parse('global a'), ast.parse('global b')) - - -def _get_body(expr): - body = ast.parse(expr).body[0] - assert isinstance(body, ast.Expr) - return body.value - - -def test_fields_same(): - assert not _fields_same(_get_body('x'), _get_body('1')) From eadaf860db8819732de22e8b4fc610577db7be9a Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 30 Jan 2021 11:56:46 -0800 Subject: [PATCH 2/4] Fix bug with calling different superclass method --- pyupgrade/_plugins/legacy.py | 11 +++++++---- tests/features/super_test.py | 9 +++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index e8443065..93f6eff5 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -26,6 +26,7 @@ from pyupgrade._token_helpers import parse_call_args FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef) +NON_LAMBDA_FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef) def _fix_yield(i: int, tokens: List[Token]) -> None: @@ -132,16 +133,18 @@ def visit_Call(self, node: ast.Call) -> None: ): self.super_offsets.add(ast_to_offset(node)) elif ( + # base.funcname(funcarg1, ...) + isinstance(node.func, ast.Attribute) and + len(node.args) >= 1 and + isinstance(node.args[0], ast.Name) and len(self._scopes) >= 2 and # last stack is a function whose first argument is the first # argument of this function - len(node.args) >= 1 and - isinstance(node.args[0], ast.Name) and - isinstance(self._scopes[-1].node, FUNC_TYPES) and + isinstance(self._scopes[-1].node, NON_LAMBDA_FUNC_TYPES) and + node.func.attr == self._scopes[-1].node.name and len(self._scopes[-1].node.args.args) >= 1 and node.args[0].id == self._scopes[-1].node.args.args[0].arg and # the function is an attribute of the contained class name - isinstance(node.func, ast.Attribute) and isinstance(self._scopes[-2].node, ast.ClassDef) and len(self._scopes[-2].node.bases) == 1 and _is_simple_base(self._scopes[-2].node.bases[0]) and diff --git a/tests/features/super_test.py b/tests/features/super_test.py index 6e2103c9..0149a7fb 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -173,6 +173,15 @@ def test_fix_super(s, expected): ' a().b.f(self)\n', id='non simple attribute base', ), + pytest.param( + 'class C:\n' + ' @classmethod\n' + ' def make(cls, instance):\n' + ' ...\n' + 'class D(C):\n' + ' def find(self):\n' + ' return C.make(self)\n', + ), ), ) def test_old_style_class_super_noop(s): From dd1fe7fae1ca4decc4c03445beecbe32fd86510a Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 25 Jul 2020 14:27:33 -0700 Subject: [PATCH 3/4] fix super replacement of multiple lines --- pyupgrade/_plugins/legacy.py | 26 ++++++-------------------- tests/features/super_test.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index 93f6eff5..8de3fa75 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -21,9 +21,7 @@ from pyupgrade._token_helpers import Block from pyupgrade._token_helpers import find_and_replace_call from pyupgrade._token_helpers import find_block_start -from pyupgrade._token_helpers import find_open_paren from pyupgrade._token_helpers import find_token -from pyupgrade._token_helpers import parse_call_args FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef) NON_LAMBDA_FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef) @@ -37,20 +35,6 @@ def _fix_yield(i: int, tokens: List[Token]) -> None: tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')] -def _fix_old_super(i: int, tokens: List[Token]) -> None: - j = find_open_paren(tokens, i) - k = j - 1 - while tokens[k].src != '.': - k -= 1 - func_args, end = parse_call_args(tokens, j) - # remove the first argument - if len(func_args) == 1: - del tokens[func_args[0][0]:func_args[0][0] + 1] - else: - del tokens[func_args[0][0]:func_args[1][0] + 1] - tokens[i:k] = [Token('CODE', 'super()')] - - def _is_simple_base(base: ast.AST) -> bool: return ( isinstance(base, ast.Name) or ( @@ -76,7 +60,7 @@ class Visitor(ast.NodeVisitor): def __init__(self) -> None: self._scopes: List[Scope] = [] self.super_offsets: Set[Offset] = set() - self.old_super_offsets: Set[Offset] = set() + self.old_super_offsets: Set[Tuple[Offset, str]] = set() self.yield_offsets: Set[Offset] = set() @contextlib.contextmanager @@ -153,7 +137,7 @@ def visit_Call(self, node: ast.Call) -> None: node.func.value, ) ): - self.old_super_offsets.add(ast_to_offset(node)) + self.old_super_offsets.add((ast_to_offset(node), node.func.attr)) self.generic_visit(node) @@ -204,8 +188,10 @@ def visit_Module( for offset in visitor.super_offsets: yield offset, super_func - for offset in visitor.old_super_offsets: - yield offset, _fix_old_super + for offset, func_name in visitor.old_super_offsets: + template = f'super().{func_name}({{rest}})' + callback = functools.partial(find_and_replace_call, template=template) + yield offset, callback for offset in visitor.yield_offsets: yield offset, _fix_yield diff --git a/tests/features/super_test.py b/tests/features/super_test.py index 0149a7fb..124f87df 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -201,6 +201,21 @@ def test_old_style_class_super_noop(s): ' super().f()\n' ' super().f(arg, arg)\n', ), + pytest.param( + 'class C(B):\n' + ' def f(self, a):\n' + ' B.f(\n' + ' self,\n' + ' a,\n' + ' )\n', + + 'class C(B):\n' + ' def f(self, a):\n' + ' super().f(\n' + ' a,\n' + ' )\n', + id='multi-line super call', + ), ), ) def test_old_style_class_super(s, expected): From 1a76829caa376722ccf8d02c6d2dd33bbe30a94f Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 25 Sep 2021 16:03:36 -0400 Subject: [PATCH 4/4] don't rewrite old super calls for __new__ --- pyupgrade/_plugins/legacy.py | 1 + tests/features/super_test.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index 8de3fa75..f7ba3763 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -126,6 +126,7 @@ def visit_Call(self, node: ast.Call) -> None: # argument of this function isinstance(self._scopes[-1].node, NON_LAMBDA_FUNC_TYPES) and node.func.attr == self._scopes[-1].node.name and + node.func.attr != '__new__' and len(self._scopes[-1].node.args.args) >= 1 and node.args[0].id == self._scopes[-1].node.args.args[0].arg and # the function is an attribute of the contained class name diff --git a/tests/features/super_test.py b/tests/features/super_test.py index 124f87df..8e3be14f 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -182,6 +182,12 @@ def test_fix_super(s, expected): ' def find(self):\n' ' return C.make(self)\n', ), + pytest.param( + 'class C(tuple):\n' + ' def __new__(cls, arg):\n' + ' return tuple.__new__(cls, (arg,))\n', + id='super() does not work properly for __new__', + ), ), ) def test_old_style_class_super_noop(s):