diff --git a/pyupgrade/_ast_helpers.py b/pyupgrade/_ast_helpers.py index 2f5c05ae..eb4ba066 100644 --- a/pyupgrade/_ast_helpers.py +++ b/pyupgrade/_ast_helpers.py @@ -1,12 +1,8 @@ import ast import warnings -from typing import Any from typing import Container from typing import Dict -from typing import Iterable from typing import Set -from typing import Tuple -from typing import Type from typing import Union from tokenize_rt import Offset @@ -61,43 +57,3 @@ def is_async_listcomp(node: ast.ListComp) -> bool: any(gen.is_async for gen in node.generators) or contains_await(node) ) - - -def _all_isinstance( - vals: Iterable[Any], - tp: Union[Type[Any], Tuple[Type[Any], ...]], -) -> bool: - return all(isinstance(v, tp) for v in vals) - - -def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: - for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): - # ignore ast attributes, they'll be covered by walk - if a1 != a2: - return False - elif _all_isinstance((v1, v2), ast.AST): - continue - elif _all_isinstance((v1, v2), (list, tuple)): - if len(v1) != len(v2): - return False - # ignore sequences which are all-ast, they'll be covered by walk - elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): - continue - elif v1 != v2: - return False - elif v1 != v2: - return False - return True - - -def targets_same(node1: ast.AST, node2: ast.AST) -> bool: - for t1, t2 in zip(ast.walk(node1), ast.walk(node2)): - # ignore `ast.Load` / `ast.Store` - if _all_isinstance((t1, t2), ast.expr_context): - continue - elif type(t1) != type(t2): - return False - elif not _fields_same(t1, t2): - return False - else: - return True diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index f7ba3763..321ffba0 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -2,19 +2,21 @@ import collections import contextlib import functools +from typing import Any from typing import Dict from typing import Generator from typing import Iterable from typing import List from typing import Set from typing import Tuple +from typing import Type +from typing import Union from tokenize_rt import Offset from tokenize_rt import Token from tokenize_rt import tokens_to_src from pyupgrade._ast_helpers import ast_to_offset -from pyupgrade._ast_helpers import targets_same from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc @@ -24,7 +26,6 @@ from pyupgrade._token_helpers import find_token FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef) -NON_LAMBDA_FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef) def _fix_yield(i: int, tokens: List[Token]) -> None: @@ -35,13 +36,44 @@ def _fix_yield(i: int, tokens: List[Token]) -> None: tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')] -def _is_simple_base(base: ast.AST) -> bool: - return ( - isinstance(base, ast.Name) or ( - isinstance(base, ast.Attribute) and - _is_simple_base(base.value) - ) - ) +def _all_isinstance( + vals: Iterable[Any], + tp: Union[Type[Any], Tuple[Type[Any], ...]], +) -> bool: + return all(isinstance(v, tp) for v in vals) + + +def _fields_same(n1: ast.AST, n2: ast.AST) -> bool: + for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)): + # ignore ast attributes, they'll be covered by walk + if a1 != a2: + return False + elif _all_isinstance((v1, v2), ast.AST): + continue + elif _all_isinstance((v1, v2), (list, tuple)): + if len(v1) != len(v2): + return False + # ignore sequences which are all-ast, they'll be covered by walk + elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST): + continue + elif v1 != v2: + return False + elif v1 != v2: + return False + return True + + +def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool: + for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)): + # ignore `ast.Load` / `ast.Store` + if _all_isinstance((t1, t2), ast.expr_context): + continue + elif type(t1) != type(t2): + return False + elif not _fields_same(t1, t2): + return False + else: + return True class Scope: @@ -60,7 +92,6 @@ class Visitor(ast.NodeVisitor): def __init__(self) -> None: self._scopes: List[Scope] = [] self.super_offsets: Set[Offset] = set() - self.old_super_offsets: Set[Tuple[Offset, str]] = set() self.yield_offsets: Set[Offset] = set() @contextlib.contextmanager @@ -106,6 +137,7 @@ def visit_Call(self, node: ast.Call) -> None: len(node.args) == 2 and isinstance(node.args[0], ast.Name) and isinstance(node.args[1], ast.Name) and + # there are at least two scopes len(self._scopes) >= 2 and # the second to last scope is the class in arg1 isinstance(self._scopes[-2].node, ast.ClassDef) and @@ -116,29 +148,6 @@ def visit_Call(self, node: ast.Call) -> None: node.args[1].id == self._scopes[-1].node.args.args[0].arg ): self.super_offsets.add(ast_to_offset(node)) - elif ( - # base.funcname(funcarg1, ...) - isinstance(node.func, ast.Attribute) and - len(node.args) >= 1 and - isinstance(node.args[0], ast.Name) and - len(self._scopes) >= 2 and - # last stack is a function whose first argument is the first - # argument of this function - isinstance(self._scopes[-1].node, NON_LAMBDA_FUNC_TYPES) and - node.func.attr == self._scopes[-1].node.name and - node.func.attr != '__new__' and - len(self._scopes[-1].node.args.args) >= 1 and - node.args[0].id == self._scopes[-1].node.args.args[0].arg and - # the function is an attribute of the contained class name - isinstance(self._scopes[-2].node, ast.ClassDef) and - len(self._scopes[-2].node.bases) == 1 and - _is_simple_base(self._scopes[-2].node.bases[0]) and - targets_same( - self._scopes[-2].node.bases[0], - node.func.value, - ) - ): - self.old_super_offsets.add((ast_to_offset(node), node.func.attr)) self.generic_visit(node) @@ -150,7 +159,7 @@ def visit_For(self, node: ast.For) -> None: isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Yield) and node.body[0].value.value is not None and - targets_same(node.target, node.body[0].value.value) and + _targets_same(node.target, node.body[0].value.value) and not node.orelse ): offset = ast_to_offset(node) @@ -189,10 +198,5 @@ def visit_Module( for offset in visitor.super_offsets: yield offset, super_func - for offset, func_name in visitor.old_super_offsets: - template = f'super().{func_name}({{rest}})' - callback = functools.partial(find_and_replace_call, template=template) - yield offset, callback - for offset in visitor.yield_offsets: yield offset, _fix_yield diff --git a/tests/ast_helpers_test.py b/tests/ast_helpers_test.py deleted file mode 100644 index 282ea729..00000000 --- a/tests/ast_helpers_test.py +++ /dev/null @@ -1,19 +0,0 @@ -import ast - -from pyupgrade._ast_helpers import _fields_same -from pyupgrade._ast_helpers import targets_same - - -def test_targets_same(): - assert targets_same(ast.parse('global a, b'), ast.parse('global a, b')) - assert not targets_same(ast.parse('global a'), ast.parse('global b')) - - -def _get_body(expr): - body = ast.parse(expr).body[0] - assert isinstance(body, ast.Expr) - return body.value - - -def test_fields_same(): - assert not _fields_same(_get_body('x'), _get_body('1')) diff --git a/tests/features/super_test.py b/tests/features/super_test.py index 8e3be14f..ee558a6d 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -122,107 +122,3 @@ def test_fix_super_noop(s): ) def test_fix_super(s, expected): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected - - -@pytest.mark.parametrize( - 's', - ( - pytest.param( - 'class C(B):\n' - ' def f(self):\n' - ' B.f(notself)\n', - id='old style super, first argument is not first function arg', - ), - pytest.param( - 'class C(B1, B2):\n' - ' def f(self):\n' - ' B1.f(self)\n', - # TODO: is this safe to rewrite? I don't think so - id='old-style super, multiple inheritance first class', - ), - pytest.param( - 'class C(B1, B2):\n' - ' def f(self):\n' - ' B2.f(self)\n', - # TODO: is this safe to rewrite? I don't think so - id='old-style super, multiple inheritance not-first class', - ), - pytest.param( - 'class C(Base):\n' - ' def f(self):\n' - ' return [Base.f(self) for _ in ()]\n', - id='super in comprehension', - ), - pytest.param( - 'class C(Base):\n' - ' def f(self):\n' - ' def g():\n' - ' Base.f(self)\n' - ' g()\n', - id='super in nested functions', - ), - pytest.param( - 'class C(not_simple()):\n' - ' def f(self):\n' - ' not_simple().f(self)\n', - id='not a simple base', - ), - pytest.param( - 'class C(a().b):\n' - ' def f(self):\n' - ' a().b.f(self)\n', - id='non simple attribute base', - ), - pytest.param( - 'class C:\n' - ' @classmethod\n' - ' def make(cls, instance):\n' - ' ...\n' - 'class D(C):\n' - ' def find(self):\n' - ' return C.make(self)\n', - ), - pytest.param( - 'class C(tuple):\n' - ' def __new__(cls, arg):\n' - ' return tuple.__new__(cls, (arg,))\n', - id='super() does not work properly for __new__', - ), - ), -) -def test_old_style_class_super_noop(s): - assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s - - -@pytest.mark.parametrize( - ('s', 'expected'), - ( - ( - 'class C(B):\n' - ' def f(self):\n' - ' B.f(self)\n' - ' B.f(self, arg, arg)\n', - 'class C(B):\n' - ' def f(self):\n' - ' super().f()\n' - ' super().f(arg, arg)\n', - ), - pytest.param( - 'class C(B):\n' - ' def f(self, a):\n' - ' B.f(\n' - ' self,\n' - ' a,\n' - ' )\n', - - 'class C(B):\n' - ' def f(self, a):\n' - ' super().f(\n' - ' a,\n' - ' )\n', - id='multi-line super call', - ), - ), -) -def test_old_style_class_super(s, expected): - assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected diff --git a/tests/features/yield_from_test.py b/tests/features/yield_from_test.py index dc040993..d3d12326 100644 --- a/tests/features/yield_from_test.py +++ b/tests/features/yield_from_test.py @@ -1,7 +1,11 @@ +import ast + import pytest from pyupgrade._data import Settings from pyupgrade._main import _fix_plugins +from pyupgrade._plugins.legacy import _fields_same +from pyupgrade._plugins.legacy import _targets_same @pytest.mark.parametrize( @@ -211,3 +215,18 @@ def test_fix_yield_from(s, expected): ) def test_fix_yield_from_noop(s): assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s + + +def test_targets_same(): + assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b')) + assert not _targets_same(ast.parse('global a'), ast.parse('global b')) + + +def _get_body(expr): + body = ast.parse(expr).body[0] + assert isinstance(body, ast.Expr) + return body.value + + +def test_fields_same(): + assert not _fields_same(_get_body('x'), _get_body('1'))