Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite old-style-class super calls #320

Merged
merged 4 commits into from Sep 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 44 additions & 0 deletions pyupgrade/_ast_helpers.py
@@ -1,8 +1,12 @@
import ast
import warnings
from typing import Any
from typing import Container
from typing import Dict
from typing import Iterable
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union

from tokenize_rt import Offset
Expand Down Expand Up @@ -57,3 +61,43 @@ def is_async_listcomp(node: ast.ListComp) -> bool:
any(gen.is_async for gen in node.generators) or
contains_await(node)
)


def _all_isinstance(
vals: Iterable[Any],
tp: Union[Type[Any], Tuple[Type[Any], ...]],
) -> bool:
return all(isinstance(v, tp) for v in vals)


def _fields_same(n1: ast.AST, n2: ast.AST) -> bool:
for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
# ignore ast attributes, they'll be covered by walk
if a1 != a2:
return False
elif _all_isinstance((v1, v2), ast.AST):
continue
elif _all_isinstance((v1, v2), (list, tuple)):
if len(v1) != len(v2):
return False
# ignore sequences which are all-ast, they'll be covered by walk
elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
continue
elif v1 != v2:
return False
elif v1 != v2:
return False
return True


def targets_same(node1: ast.AST, node2: ast.AST) -> bool:
for t1, t2 in zip(ast.walk(node1), ast.walk(node2)):
# ignore `ast.Load` / `ast.Store`
if _all_isinstance((t1, t2), ast.expr_context):
continue
elif type(t1) != type(t2):
return False
elif not _fields_same(t1, t2):
return False
else:
return True
82 changes: 39 additions & 43 deletions pyupgrade/_plugins/legacy.py
Expand Up @@ -2,21 +2,19 @@
import collections
import contextlib
import functools
from typing import Any
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union

from tokenize_rt import Offset
from tokenize_rt import Token
from tokenize_rt import tokens_to_src

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import targets_same
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
Expand All @@ -26,6 +24,7 @@
from pyupgrade._token_helpers import find_token

FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef)
NON_LAMBDA_FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef)


def _fix_yield(i: int, tokens: List[Token]) -> None:
Expand All @@ -36,44 +35,13 @@ def _fix_yield(i: int, tokens: List[Token]) -> None:
tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')]


def _all_isinstance(
vals: Iterable[Any],
tp: Union[Type[Any], Tuple[Type[Any], ...]],
) -> bool:
return all(isinstance(v, tp) for v in vals)


def _fields_same(n1: ast.AST, n2: ast.AST) -> bool:
for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
# ignore ast attributes, they'll be covered by walk
if a1 != a2:
return False
elif _all_isinstance((v1, v2), ast.AST):
continue
elif _all_isinstance((v1, v2), (list, tuple)):
if len(v1) != len(v2):
return False
# ignore sequences which are all-ast, they'll be covered by walk
elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
continue
elif v1 != v2:
return False
elif v1 != v2:
return False
return True


def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool:
for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)):
# ignore `ast.Load` / `ast.Store`
if _all_isinstance((t1, t2), ast.expr_context):
continue
elif type(t1) != type(t2):
return False
elif not _fields_same(t1, t2):
return False
else:
return True
def _is_simple_base(base: ast.AST) -> bool:
return (
isinstance(base, ast.Name) or (
isinstance(base, ast.Attribute) and
_is_simple_base(base.value)
)
)


class Scope:
Expand All @@ -92,6 +60,7 @@ class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
self._scopes: List[Scope] = []
self.super_offsets: Set[Offset] = set()
self.old_super_offsets: Set[Tuple[Offset, str]] = set()
self.yield_offsets: Set[Offset] = set()

@contextlib.contextmanager
Expand Down Expand Up @@ -137,7 +106,6 @@ def visit_Call(self, node: ast.Call) -> None:
len(node.args) == 2 and
isinstance(node.args[0], ast.Name) and
isinstance(node.args[1], ast.Name) and
# there are at least two scopes
len(self._scopes) >= 2 and
# the second to last scope is the class in arg1
isinstance(self._scopes[-2].node, ast.ClassDef) and
Expand All @@ -148,6 +116,29 @@ def visit_Call(self, node: ast.Call) -> None:
node.args[1].id == self._scopes[-1].node.args.args[0].arg
):
self.super_offsets.add(ast_to_offset(node))
elif (
# base.funcname(funcarg1, ...)
isinstance(node.func, ast.Attribute) and
len(node.args) >= 1 and
isinstance(node.args[0], ast.Name) and
len(self._scopes) >= 2 and
# last stack is a function whose first argument is the first
# argument of this function
isinstance(self._scopes[-1].node, NON_LAMBDA_FUNC_TYPES) and
node.func.attr == self._scopes[-1].node.name and
node.func.attr != '__new__' and
len(self._scopes[-1].node.args.args) >= 1 and
node.args[0].id == self._scopes[-1].node.args.args[0].arg and
# the function is an attribute of the contained class name
isinstance(self._scopes[-2].node, ast.ClassDef) and
len(self._scopes[-2].node.bases) == 1 and
_is_simple_base(self._scopes[-2].node.bases[0]) and
targets_same(
self._scopes[-2].node.bases[0],
node.func.value,
)
):
self.old_super_offsets.add((ast_to_offset(node), node.func.attr))

self.generic_visit(node)

Expand All @@ -159,7 +150,7 @@ def visit_For(self, node: ast.For) -> None:
isinstance(node.body[0], ast.Expr) and
isinstance(node.body[0].value, ast.Yield) and
node.body[0].value.value is not None and
_targets_same(node.target, node.body[0].value.value) and
targets_same(node.target, node.body[0].value.value) and
not node.orelse
):
offset = ast_to_offset(node)
Expand Down Expand Up @@ -198,5 +189,10 @@ def visit_Module(
for offset in visitor.super_offsets:
yield offset, super_func

for offset, func_name in visitor.old_super_offsets:
template = f'super().{func_name}({{rest}})'
callback = functools.partial(find_and_replace_call, template=template)
yield offset, callback

for offset in visitor.yield_offsets:
yield offset, _fix_yield
19 changes: 19 additions & 0 deletions tests/ast_helpers_test.py
@@ -0,0 +1,19 @@
import ast

from pyupgrade._ast_helpers import _fields_same
from pyupgrade._ast_helpers import targets_same


def test_targets_same():
assert targets_same(ast.parse('global a, b'), ast.parse('global a, b'))
assert not targets_same(ast.parse('global a'), ast.parse('global b'))


def _get_body(expr):
body = ast.parse(expr).body[0]
assert isinstance(body, ast.Expr)
return body.value


def test_fields_same():
assert not _fields_same(_get_body('x'), _get_body('1'))
104 changes: 104 additions & 0 deletions tests/features/super_test.py
Expand Up @@ -122,3 +122,107 @@ def test_fix_super_noop(s):
)
def test_fix_super(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected


@pytest.mark.parametrize(
's',
(
pytest.param(
'class C(B):\n'
' def f(self):\n'
' B.f(notself)\n',
id='old style super, first argument is not first function arg',
),
pytest.param(
'class C(B1, B2):\n'
' def f(self):\n'
' B1.f(self)\n',
# TODO: is this safe to rewrite? I don't think so
id='old-style super, multiple inheritance first class',
),
pytest.param(
'class C(B1, B2):\n'
' def f(self):\n'
' B2.f(self)\n',
# TODO: is this safe to rewrite? I don't think so
id='old-style super, multiple inheritance not-first class',
),
pytest.param(
'class C(Base):\n'
' def f(self):\n'
' return [Base.f(self) for _ in ()]\n',
id='super in comprehension',
),
pytest.param(
'class C(Base):\n'
' def f(self):\n'
' def g():\n'
' Base.f(self)\n'
' g()\n',
id='super in nested functions',
),
pytest.param(
'class C(not_simple()):\n'
' def f(self):\n'
' not_simple().f(self)\n',
id='not a simple base',
),
pytest.param(
'class C(a().b):\n'
' def f(self):\n'
' a().b.f(self)\n',
id='non simple attribute base',
),
pytest.param(
'class C:\n'
' @classmethod\n'
' def make(cls, instance):\n'
' ...\n'
'class D(C):\n'
' def find(self):\n'
' return C.make(self)\n',
),
pytest.param(
'class C(tuple):\n'
' def __new__(cls, arg):\n'
' return tuple.__new__(cls, (arg,))\n',
id='super() does not work properly for __new__',
),
),
)
def test_old_style_class_super_noop(s):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
(
'class C(B):\n'
' def f(self):\n'
' B.f(self)\n'
' B.f(self, arg, arg)\n',
'class C(B):\n'
' def f(self):\n'
' super().f()\n'
' super().f(arg, arg)\n',
),
pytest.param(
'class C(B):\n'
' def f(self, a):\n'
' B.f(\n'
' self,\n'
' a,\n'
' )\n',

'class C(B):\n'
' def f(self, a):\n'
' super().f(\n'
' a,\n'
' )\n',
id='multi-line super call',
),
),
)
def test_old_style_class_super(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected
19 changes: 0 additions & 19 deletions tests/features/yield_from_test.py
@@ -1,11 +1,7 @@
import ast

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins
from pyupgrade._plugins.legacy import _fields_same
from pyupgrade._plugins.legacy import _targets_same


@pytest.mark.parametrize(
Expand Down Expand Up @@ -215,18 +211,3 @@ def test_fix_yield_from(s, expected):
)
def test_fix_yield_from_noop(s):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s


def test_targets_same():
assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b'))
assert not _targets_same(ast.parse('global a'), ast.parse('global b'))


def _get_body(expr):
body = ast.parse(expr).body[0]
assert isinstance(body, ast.Expr)
return body.value


def test_fields_same():
assert not _fields_same(_get_body('x'), _get_body('1'))