diff --git a/pyupgrade/_ast_helpers.py b/pyupgrade/_ast_helpers.py index eb4ba066..2f5c05ae 100644 --- a/pyupgrade/_ast_helpers.py +++ b/pyupgrade/_ast_helpers.py @@ -1,8 +1,12 @@ import ast import warnings +from typing import Any from typing import Container from typing import Dict +from typing import Iterable from typing import Set +from typing import Tuple +from typing import Type from typing import Union from tokenize_rt import Offset @@ -57,3 +61,43 @@ def is_async_listcomp(node: ast.ListComp) -> bool: any(gen.is_async for gen in node.generators) or contains_await(node) ) + + +def _all_isinstance( + vals: Iterable[Any], + tp: Union[Type[Any], Tuple[Type[Any], ...]], +) -> bool: + return all(isinstance(v, tp) for v in vals) + + +def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: + for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): + # ignore ast attributes, they'll be covered by walk + if a1 != a2: + return False + elif _all_isinstance((v1, v2), ast.AST): + continue + elif _all_isinstance((v1, v2), (list, tuple)): + if len(v1) != len(v2): + return False + # ignore sequences which are all-ast, they'll be covered by walk + elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): + continue + elif v1 != v2: + return False + elif v1 != v2: + return False + return True + + +def targets_same(node1: ast.AST, node2: ast.AST) -> bool: + for t1, t2 in zip(ast.walk(node1), ast.walk(node2)): + # ignore `ast.Load` / `ast.Store` + if _all_isinstance((t1, t2), ast.expr_context): + continue + elif type(t1) != type(t2): + return False + elif not _fields_same(t1, t2): + return False + else: + return True diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index 321ffba0..f7ba3763 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -2,21 +2,19 @@ import collections import contextlib import functools -from typing import Any from typing import Dict from typing import Generator from typing import Iterable from typing import List from typing import Set from typing import Tuple -from typing import Type -from typing import Union from tokenize_rt import Offset from tokenize_rt import Token from tokenize_rt import tokens_to_src from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import targets_same from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc @@ -26,6 +24,7 @@ from pyupgrade._token_helpers import find_token FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef) +NON_LAMBDA_FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef) def _fix_yield(i: int, tokens: List[Token]) -> None: @@ -36,44 +35,13 @@ def _fix_yield(i: int, tokens: List[Token]) -> None: tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')] -def _all_isinstance( - vals: Iterable[Any], - tp: Union[Type[Any], Tuple[Type[Any], ...]], -) -> bool: - return all(isinstance(v, tp) for v in vals) - - -def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: - for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): - # ignore ast attributes, they'll be covered by walk - if a1 != a2: - return False - elif _all_isinstance((v1, v2), ast.AST): - continue - elif _all_isinstance((v1, v2), (list, tuple)): - if len(v1) != len(v2): - return False - # ignore sequences which are all-ast, they'll be covered by walk - elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): - continue - elif v1 != v2: - return False - elif v1 != v2: - return False - return True - - -def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool: - for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)): - # ignore `ast.Load` / `ast.Store` - if _all_isinstance((t1, t2), ast.expr_context): - continue - elif type(t1) != type(t2): - return False - elif not _fields_same(t1, t2): - return False - else: - return True +def _is_simple_base(base: ast.AST) -> bool: + return ( + isinstance(base, ast.Name) or ( + isinstance(base, ast.Attribute) and + _is_simple_base(base.value) + ) + ) class Scope: @@ -92,6 +60,7 @@ class Visitor(ast.NodeVisitor): def __init__(self) -> None: self._scopes: List[Scope] = [] self.super_offsets: Set[Offset] = set() + self.old_super_offsets: Set[Tuple[Offset, str]] = set() self.yield_offsets: Set[Offset] = set() @contextlib.contextmanager @@ -137,7 +106,6 @@ def visit_Call(self, node: ast.Call) -> None: len(node.args) == 2 and isinstance(node.args[0], ast.Name) and isinstance(node.args[1], ast.Name) and - # there are at least two scopes len(self._scopes) >= 2 and # the second to last scope is the class in arg1 isinstance(self._scopes[-2].node, ast.ClassDef) and @@ -148,6 +116,29 @@ def visit_Call(self, node: ast.Call) -> None: node.args[1].id == self._scopes[-1].node.args.args[0].arg ): self.super_offsets.add(ast_to_offset(node)) + elif ( + # base.funcname(funcarg1, ...) + isinstance(node.func, ast.Attribute) and + len(node.args) >= 1 and + isinstance(node.args[0], ast.Name) and + len(self._scopes) >= 2 and + # last stack is a function whose first argument is the first + # argument of this function + isinstance(self._scopes[-1].node, NON_LAMBDA_FUNC_TYPES) and + node.func.attr == self._scopes[-1].node.name and + node.func.attr != '__new__' and + len(self._scopes[-1].node.args.args) >= 1 and + node.args[0].id == self._scopes[-1].node.args.args[0].arg and + # the function is an attribute of the contained class name + isinstance(self._scopes[-2].node, ast.ClassDef) and + len(self._scopes[-2].node.bases) == 1 and + _is_simple_base(self._scopes[-2].node.bases[0]) and + targets_same( + self._scopes[-2].node.bases[0], + node.func.value, + ) + ): + self.old_super_offsets.add((ast_to_offset(node), node.func.attr)) self.generic_visit(node) @@ -159,7 +150,7 @@ def visit_For(self, node: ast.For) -> None: isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Yield) and node.body[0].value.value is not None and - _targets_same(node.target, node.body[0].value.value) and + targets_same(node.target, node.body[0].value.value) and not node.orelse ): offset = ast_to_offset(node) @@ -198,5 +189,10 @@ def visit_Module( for offset in visitor.super_offsets: yield offset, super_func + for offset, func_name in visitor.old_super_offsets: + template = f'super().{func_name}({{rest}})' + callback = functools.partial(find_and_replace_call, template=template) + yield offset, callback + for offset in visitor.yield_offsets: yield offset, _fix_yield diff --git a/tests/ast_helpers_test.py b/tests/ast_helpers_test.py new file mode 100644 index 00000000..282ea729 --- /dev/null +++ b/tests/ast_helpers_test.py @@ -0,0 +1,19 @@ +import ast + +from pyupgrade._ast_helpers import _fields_same +from pyupgrade._ast_helpers import targets_same + + +def test_targets_same(): + assert targets_same(ast.parse('global a, b'), ast.parse('global a, b')) + assert not targets_same(ast.parse('global a'), ast.parse('global b')) + + +def _get_body(expr): + body = ast.parse(expr).body[0] + assert isinstance(body, ast.Expr) + return body.value + + +def test_fields_same(): + assert not _fields_same(_get_body('x'), _get_body('1')) diff --git a/tests/features/super_test.py b/tests/features/super_test.py index ee558a6d..8e3be14f 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -122,3 +122,107 @@ def test_fix_super_noop(s): ) def test_fix_super(s, expected): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(notself)\n', + id='old style super, first argument is not first function arg', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B1.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance first class', + ), + pytest.param( + 'class C(B1, B2):\n' + ' def f(self):\n' + ' B2.f(self)\n', + # TODO: is this safe to rewrite? I don't think so + id='old-style super, multiple inheritance not-first class', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' return [Base.f(self) for _ in ()]\n', + id='super in comprehension', + ), + pytest.param( + 'class C(Base):\n' + ' def f(self):\n' + ' def g():\n' + ' Base.f(self)\n' + ' g()\n', + id='super in nested functions', + ), + pytest.param( + 'class C(not_simple()):\n' + ' def f(self):\n' + ' not_simple().f(self)\n', + id='not a simple base', + ), + pytest.param( + 'class C(a().b):\n' + ' def f(self):\n' + ' a().b.f(self)\n', + id='non simple attribute base', + ), + pytest.param( + 'class C:\n' + ' @classmethod\n' + ' def make(cls, instance):\n' + ' ...\n' + 'class D(C):\n' + ' def find(self):\n' + ' return C.make(self)\n', + ), + pytest.param( + 'class C(tuple):\n' + ' def __new__(cls, arg):\n' + ' return tuple.__new__(cls, (arg,))\n', + id='super() does not work properly for __new__', + ), + ), +) +def test_old_style_class_super_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + ( + 'class C(B):\n' + ' def f(self):\n' + ' B.f(self)\n' + ' B.f(self, arg, arg)\n', + 'class C(B):\n' + ' def f(self):\n' + ' super().f()\n' + ' super().f(arg, arg)\n', + ), + pytest.param( + 'class C(B):\n' + ' def f(self, a):\n' + ' B.f(\n' + ' self,\n' + ' a,\n' + ' )\n', + + 'class C(B):\n' + ' def f(self, a):\n' + ' super().f(\n' + ' a,\n' + ' )\n', + id='multi-line super call', + ), + ), +) +def test_old_style_class_super(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected diff --git a/tests/features/yield_from_test.py b/tests/features/yield_from_test.py index d3d12326..dc040993 100644 --- a/tests/features/yield_from_test.py +++ b/tests/features/yield_from_test.py @@ -1,11 +1,7 @@ -import ast - import pytest from pyupgrade._data import Settings from pyupgrade._main import _fix_plugins -from pyupgrade._plugins.legacy import _fields_same -from pyupgrade._plugins.legacy import _targets_same @pytest.mark.parametrize( @@ -215,18 +211,3 @@ def test_fix_yield_from(s, expected): ) def test_fix_yield_from_noop(s): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s - - -def test_targets_same(): - assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b')) - assert not _targets_same(ast.parse('global a'), ast.parse('global b')) - - -def _get_body(expr): - body = ast.parse(expr).body[0] - assert isinstance(body, ast.Expr) - return body.value - - -def test_fields_same(): - assert not _fields_same(_get_body('x'), _get_body('1'))