Skip to content

Commit

Permalink
Revert "Revert "Merge pull request #317 from asottile/old_super""
Browse files Browse the repository at this point in the history
This reverts commit 2719335.
  • Loading branch information
asottile committed Sep 25, 2021
1 parent 9c40758 commit 17cb81f
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 61 deletions.
44 changes: 44 additions & 0 deletions pyupgrade/_ast_helpers.py
@@ -1,8 +1,12 @@
import ast
import warnings
from typing import Any
from typing import Container
from typing import Dict
from typing import Iterable
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union

from tokenize_rt import Offset
Expand Down Expand Up @@ -57,3 +61,43 @@ def is_async_listcomp(node: ast.ListComp) -> bool:
any(gen.is_async for gen in node.generators) or
contains_await(node)
)


def _all_isinstance(
vals: Iterable[Any],
tp: Union[Type[Any], Tuple[Type[Any], ...]],
) -> bool:
return all(isinstance(v, tp) for v in vals)


def _fields_same(n1: ast.AST, n2: ast.AST) -> bool:
for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
# ignore ast attributes, they'll be covered by walk
if a1 != a2:
return False
elif _all_isinstance((v1, v2), ast.AST):
continue
elif _all_isinstance((v1, v2), (list, tuple)):
if len(v1) != len(v2):
return False
# ignore sequences which are all-ast, they'll be covered by walk
elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
continue
elif v1 != v2:
return False
elif v1 != v2:
return False
return True


def targets_same(node1: ast.AST, node2: ast.AST) -> bool:
for t1, t2 in zip(ast.walk(node1), ast.walk(node2)):
# ignore `ast.Load` / `ast.Store`
if _all_isinstance((t1, t2), ast.expr_context):
continue
elif type(t1) != type(t2):
return False
elif not _fields_same(t1, t2):
return False
else:
return True
90 changes: 48 additions & 42 deletions pyupgrade/_plugins/legacy.py
Expand Up @@ -2,28 +2,28 @@
import collections
import contextlib
import functools
from typing import Any
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union

from tokenize_rt import Offset
from tokenize_rt import Token
from tokenize_rt import tokens_to_src

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import targets_same
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import Block
from pyupgrade._token_helpers import find_and_replace_call
from pyupgrade._token_helpers import find_block_start
from pyupgrade._token_helpers import find_open_paren
from pyupgrade._token_helpers import find_token
from pyupgrade._token_helpers import parse_call_args

FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef)

Expand All @@ -36,44 +36,27 @@ def _fix_yield(i: int, tokens: List[Token]) -> None:
tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')]


def _all_isinstance(
vals: Iterable[Any],
tp: Union[Type[Any], Tuple[Type[Any], ...]],
) -> bool:
return all(isinstance(v, tp) for v in vals)


def _fields_same(n1: ast.AST, n2: ast.AST) -> bool:
for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
# ignore ast attributes, they'll be covered by walk
if a1 != a2:
return False
elif _all_isinstance((v1, v2), ast.AST):
continue
elif _all_isinstance((v1, v2), (list, tuple)):
if len(v1) != len(v2):
return False
# ignore sequences which are all-ast, they'll be covered by walk
elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
continue
elif v1 != v2:
return False
elif v1 != v2:
return False
return True


def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool:
for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)):
# ignore `ast.Load` / `ast.Store`
if _all_isinstance((t1, t2), ast.expr_context):
continue
elif type(t1) != type(t2):
return False
elif not _fields_same(t1, t2):
return False
def _fix_old_super(i: int, tokens: List[Token]) -> None:
j = find_open_paren(tokens, i)
k = j - 1
while tokens[k].src != '.':
k -= 1
func_args, end = parse_call_args(tokens, j)
# remove the first argument
if len(func_args) == 1:
del tokens[func_args[0][0]:func_args[0][0] + 1]
else:
return True
del tokens[func_args[0][0]:func_args[1][0] + 1]
tokens[i:k] = [Token('CODE', 'super()')]


def _is_simple_base(base: ast.AST) -> bool:
return (
isinstance(base, ast.Name) or (
isinstance(base, ast.Attribute) and
_is_simple_base(base.value)
)
)


class Scope:
Expand All @@ -92,6 +75,7 @@ class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
self._scopes: List[Scope] = []
self.super_offsets: Set[Offset] = set()
self.old_super_offsets: Set[Offset] = set()
self.yield_offsets: Set[Offset] = set()

@contextlib.contextmanager
Expand Down Expand Up @@ -137,7 +121,6 @@ def visit_Call(self, node: ast.Call) -> None:
len(node.args) == 2 and
isinstance(node.args[0], ast.Name) and
isinstance(node.args[1], ast.Name) and
# there are at least two scopes
len(self._scopes) >= 2 and
# the second to last scope is the class in arg1
isinstance(self._scopes[-2].node, ast.ClassDef) and
Expand All @@ -148,6 +131,26 @@ def visit_Call(self, node: ast.Call) -> None:
node.args[1].id == self._scopes[-1].node.args.args[0].arg
):
self.super_offsets.add(ast_to_offset(node))
elif (
len(self._scopes) >= 2 and
# last stack is a function whose first argument is the first
# argument of this function
len(node.args) >= 1 and
isinstance(node.args[0], ast.Name) and
isinstance(self._scopes[-1].node, FUNC_TYPES) and
len(self._scopes[-1].node.args.args) >= 1 and
node.args[0].id == self._scopes[-1].node.args.args[0].arg and
# the function is an attribute of the contained class name
isinstance(node.func, ast.Attribute) and
isinstance(self._scopes[-2].node, ast.ClassDef) and
len(self._scopes[-2].node.bases) == 1 and
_is_simple_base(self._scopes[-2].node.bases[0]) and
targets_same(
self._scopes[-2].node.bases[0],
node.func.value,
)
):
self.old_super_offsets.add(ast_to_offset(node))

self.generic_visit(node)

Expand All @@ -159,7 +162,7 @@ def visit_For(self, node: ast.For) -> None:
isinstance(node.body[0], ast.Expr) and
isinstance(node.body[0].value, ast.Yield) and
node.body[0].value.value is not None and
_targets_same(node.target, node.body[0].value.value) and
targets_same(node.target, node.body[0].value.value) and
not node.orelse
):
offset = ast_to_offset(node)
Expand Down Expand Up @@ -198,5 +201,8 @@ def visit_Module(
for offset in visitor.super_offsets:
yield offset, super_func

for offset in visitor.old_super_offsets:
yield offset, _fix_old_super

for offset in visitor.yield_offsets:
yield offset, _fix_yield
19 changes: 19 additions & 0 deletions tests/ast_helpers_test.py
@@ -0,0 +1,19 @@
import ast

from pyupgrade._ast_helpers import _fields_same
from pyupgrade._ast_helpers import targets_same


def test_targets_same():
assert targets_same(ast.parse('global a, b'), ast.parse('global a, b'))
assert not targets_same(ast.parse('global a'), ast.parse('global b'))


def _get_body(expr):
body = ast.parse(expr).body[0]
assert isinstance(body, ast.Expr)
return body.value


def test_fields_same():
assert not _fields_same(_get_body('x'), _get_body('1'))
74 changes: 74 additions & 0 deletions tests/features/super_test.py
Expand Up @@ -122,3 +122,77 @@ def test_fix_super_noop(s):
)
def test_fix_super(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected


@pytest.mark.parametrize(
's',
(
pytest.param(
'class C(B):\n'
' def f(self):\n'
' B.f(notself)\n',
id='old style super, first argument is not first function arg',
),
pytest.param(
'class C(B1, B2):\n'
' def f(self):\n'
' B1.f(self)\n',
# TODO: is this safe to rewrite? I don't think so
id='old-style super, multiple inheritance first class',
),
pytest.param(
'class C(B1, B2):\n'
' def f(self):\n'
' B2.f(self)\n',
# TODO: is this safe to rewrite? I don't think so
id='old-style super, multiple inheritance not-first class',
),
pytest.param(
'class C(Base):\n'
' def f(self):\n'
' return [Base.f(self) for _ in ()]\n',
id='super in comprehension',
),
pytest.param(
'class C(Base):\n'
' def f(self):\n'
' def g():\n'
' Base.f(self)\n'
' g()\n',
id='super in nested functions',
),
pytest.param(
'class C(not_simple()):\n'
' def f(self):\n'
' not_simple().f(self)\n',
id='not a simple base',
),
pytest.param(
'class C(a().b):\n'
' def f(self):\n'
' a().b.f(self)\n',
id='non simple attribute base',
),
),
)
def test_old_style_class_super_noop(s):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
(
'class C(B):\n'
' def f(self):\n'
' B.f(self)\n'
' B.f(self, arg, arg)\n',
'class C(B):\n'
' def f(self):\n'
' super().f()\n'
' super().f(arg, arg)\n',
),
),
)
def test_old_style_class_super(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected
19 changes: 0 additions & 19 deletions tests/features/yield_from_test.py
@@ -1,11 +1,7 @@
import ast

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins
from pyupgrade._plugins.legacy import _fields_same
from pyupgrade._plugins.legacy import _targets_same


@pytest.mark.parametrize(
Expand Down Expand Up @@ -215,18 +211,3 @@ def test_fix_yield_from(s, expected):
)
def test_fix_yield_from_noop(s):
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s


def test_targets_same():
assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b'))
assert not _targets_same(ast.parse('global a'), ast.parse('global b'))


def _get_body(expr):
body = ast.parse(expr).body[0]
assert isinstance(body, ast.Expr)
return body.value


def test_fields_same():
assert not _fields_same(_get_body('x'), _get_body('1'))

0 comments on commit 17cb81f

Please sign in to comment.