Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite old-style-class super calls #317

Merged
merged 1 commit into from Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 44 additions & 5 deletions pyupgrade.py
Expand Up @@ -1157,8 +1157,8 @@ def fields_same(n1: ast.AST, n2: ast.AST) -> bool:
return True


def targets_same(target: ast.AST, yield_value: ast.AST) -> bool:
for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)):
def targets_same(node1: ast.AST, node2: ast.AST) -> bool:
for t1, t2 in zip(ast.walk(node1), ast.walk(node2)):
# ignore `ast.Load` / `ast.Store`
if _all_isinstance((t1, t2), ast.expr_context):
continue
Expand All @@ -1177,6 +1177,15 @@ def _is_codec(encoding: str, name: str) -> bool:
return False


def _is_simple_base(base: ast.AST) -> bool:
return (
isinstance(base, ast.Name) or (
isinstance(base, ast.Attribute) and
_is_simple_base(base.value)
)
)


class FindPy3Plus(ast.NodeVisitor):
OS_ERROR_ALIASES = frozenset((
'EnvironmentError',
Expand All @@ -1195,8 +1204,9 @@ class FindPy3Plus(ast.NodeVisitor):
MOCK_MODULES = frozenset(('mock', 'mock.mock'))

class ClassInfo:
def __init__(self, name: str) -> None:
self.name = name
def __init__(self, node: ast.ClassDef) -> None:
self.bases = node.bases
self.name = node.name
self.def_depth = 0
self.first_arg_name = ''

Expand Down Expand Up @@ -1249,6 +1259,7 @@ def __init__(self, keep_mock: bool) -> None:
self._class_info_stack: List[FindPy3Plus.ClassInfo] = []
self._in_comp = 0
self.super_calls: Dict[Offset, ast.Call] = {}
self.old_style_super_calls: Set[Offset] = set()
self._in_async_def = False
self._scope_stack: List[FindPy3Plus.Scope] = []
self.yield_from_fors: Set[Offset] = set()
Expand Down Expand Up @@ -1376,7 +1387,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
):
self.six_with_metaclass.add(_ast_to_offset(node.bases[0]))

self._class_info_stack.append(FindPy3Plus.ClassInfo(node.name))
self._class_info_stack.append(FindPy3Plus.ClassInfo(node))
self.generic_visit(node)
self._class_info_stack.pop()

Expand Down Expand Up @@ -1542,6 +1553,21 @@ def visit_Call(self, node: ast.Call) -> None:
node.args[1].id == self._class_info_stack[-1].first_arg_name
):
self.super_calls[_ast_to_offset(node)] = node
elif (
not self._in_comp and
self._class_info_stack and
self._class_info_stack[-1].def_depth == 1 and
len(self._class_info_stack[-1].bases) == 1 and
_is_simple_base(self._class_info_stack[-1].bases[0]) and
isinstance(node.func, ast.Attribute) and
targets_same(
self._class_info_stack[-1].bases[0], node.func.value,
) and
len(node.args) >= 1 and
isinstance(node.args[0], ast.Name) and
node.args[0].id == self._class_info_stack[-1].first_arg_name
):
self.old_style_super_calls.add(_ast_to_offset(node))
elif (
(
self._is_six(node.func, SIX_NATIVE_STR) or
Expand Down Expand Up @@ -2038,6 +2064,7 @@ def _fix_py3_plus(
visitor.six_type_ctx,
visitor.six_with_metaclass,
visitor.super_calls,
visitor.old_style_super_calls,
visitor.yield_from_fors,
)):
return contents_text
Expand Down Expand Up @@ -2197,6 +2224,18 @@ def _replace(i: int, mapping: Dict[str, str], node: NameOrAttr) -> None:
call = visitor.super_calls[token.offset]
victims = _victims(tokens, i, call, gen=False)
del tokens[victims.starts[0] + 1:victims.ends[-1]]
elif token.offset in visitor.old_style_super_calls:
j = _find_open_paren(tokens, i)
k = j - 1
while tokens[k].src != '.':
k -= 1
func_args, end = _parse_call_args(tokens, j)
# remove the first argument
if len(func_args) == 1:
del tokens[func_args[0][0]:func_args[0][0] + 1]
else:
del tokens[func_args[0][0]:func_args[1][0] + 1]
tokens[i:k] = [Token('CODE', 'super()')]
elif token.offset in visitor.encode_calls:
i = _find_open_paren(tokens, i)
call = visitor.encode_calls[token.offset]
Expand Down
74 changes: 74 additions & 0 deletions tests/super_test.py
Expand Up @@ -121,3 +121,77 @@ def test_fix_super_noop(s):
)
def test_fix_super(s, expected):
assert _fix_py3_plus(s, (3,)) == expected


@pytest.mark.parametrize(
's',
(
pytest.param(
'class C(B):\n'
' def f(self):\n'
' B.f(notself)\n',
id='old style super, first argument is not first function arg',
),
pytest.param(
'class C(B1, B2):\n'
' def f(self):\n'
' B1.f(self)\n',
# TODO: is this safe to rewrite? I don't think so
id='old-style super, multiple inheritance first class',
),
pytest.param(
'class C(B1, B2):\n'
' def f(self):\n'
' B2.f(self)\n',
# TODO: is this safe to rewrite? I don't think so
id='old-style super, multiple inheritance not-first class',
),
pytest.param(
'class C(Base):\n'
' def f(self):\n'
' return [Base.f(self) for _ in ()]\n',
id='super in comprehension',
),
pytest.param(
'class C(Base):\n'
' def f(self):\n'
' def g():\n'
' Base.f(self)\n'
' g()\n',
id='super in nested functions',
),
pytest.param(
'class C(not_simple()):\n'
' def f(self):\n'
' not_simple().f(self)\n',
id='not a simple base',
),
pytest.param(
'class C(a().b):\n'
' def f(self):\n'
' a().b.f(self)\n',
id='non simple attribute base',
),
),
)
def test_old_style_class_super_noop(s):
assert _fix_py3_plus(s, (3,)) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
(
'class C(B):\n'
' def f(self):\n'
' B.f(self)\n'
' B.f(self, arg, arg)\n',
'class C(B):\n'
' def f(self):\n'
' super().f()\n'
' super().f(arg, arg)\n',
),
),
)
def test_old_style_class_super(s, expected):
assert _fix_py3_plus(s, (3,)) == expected