Skip to content

Commit

Permalink
Merge pull request #360 from asottile/six_calls_expr_context
Browse files Browse the repository at this point in the history
parenthesize expressions when replacing six calls when needed
  • Loading branch information
asottile committed Nov 10, 2020
2 parents 20d5c9e + 0ff29b7 commit b1773d9
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
36 changes: 33 additions & 3 deletions pyupgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@

_KEYWORDS = frozenset(keyword.kwlist)

_EXPR_NEEDS_PARENS: Tuple[Type[ast.expr], ...] = (
ast.Await, ast.BinOp, ast.BoolOp, ast.Compare, ast.GeneratorExp, ast.IfExp,
ast.Lambda, ast.UnaryOp,
)
if sys.version_info >= (3, 8): # pragma: no cover (py38+)
_EXPR_NEEDS_PARENS += (ast.NamedExpr,)


def parse_format(s: str) -> Tuple[DotFormatPart, ...]:
"""Makes the empty string not a special case. In the stdlib, there's
Expand Down Expand Up @@ -1103,7 +1110,6 @@ def _fix_percent_format(contents_text: str) -> str:
'u': '{args[0]}',
'byte2int': '{args[0]}[0]',
'indexbytes': '{args[0]}[{rest}]',
'int2byte': 'bytes(({args[0]},))',
'iteritems': '{args[0]}.items()',
'iterkeys': '{args[0]}.keys()',
'itervalues': '{args[0]}.values()',
Expand All @@ -1122,6 +1128,7 @@ def _fix_percent_format(contents_text: str) -> str:
'assertRaisesRegex': '{args[0]}.assertRaisesRegex({rest})',
'assertRegex': '{args[0]}.assertRegex({rest})',
}
SIX_INT2BYTE_TMPL = 'bytes(({args[0]},))'
SIX_B_TMPL = 'b{args[0]}'
WITH_METACLASS_NO_BASES_TMPL = 'metaclass={args[0]}'
WITH_METACLASS_BASES_TMPL = '{rest}, metaclass={args[0]}'
Expand Down Expand Up @@ -1244,6 +1251,7 @@ def __init__(self, keep_mock: bool) -> None:
self.six_add_metaclass: Set[Offset] = set()
self.six_b: Set[Offset] = set()
self.six_calls: Dict[Offset, ast.Call] = {}
self.six_calls_int2byte: Set[Offset] = set()
self.six_iter: Dict[Offset, ast.Call] = {}
self._previous_node: Optional[ast.AST] = None
self.six_raise_from: Set[Offset] = set()
Expand Down Expand Up @@ -1534,8 +1542,18 @@ def visit_Call(self, node: ast.Call) -> None:
self.six_type_ctx[_ast_to_offset(node.args[1])] = arg
elif self._is_six(node.func, ('b', 'ensure_binary')):
self.six_b.add(_ast_to_offset(node))
elif self._is_six(node.func, SIX_CALLS) and not _starargs(node):
elif (
self._is_six(node.func, SIX_CALLS) and
node.args and
not _starargs(node)
):
self.six_calls[_ast_to_offset(node)] = node
elif (
self._is_six(node.func, ('int2byte',)) and
node.args and
not _starargs(node)
):
self.six_calls_int2byte.add(_ast_to_offset(node))
elif (
isinstance(node.func, ast.Name) and
node.func.id == 'next' and
Expand Down Expand Up @@ -2006,8 +2024,12 @@ def _replace_call(
end: int,
args: List[Tuple[int, int]],
tmpl: str,
*,
parens: Sequence[int] = (),
) -> None:
arg_strs = [_arg_str(tokens, *arg) for arg in args]
for paren in parens:
arg_strs[paren] = f'({arg_strs[paren]})'

start_rest = args[0][1] + 1
while (
Expand Down Expand Up @@ -2062,6 +2084,7 @@ def _fix_py3_plus(
visitor.six_add_metaclass,
visitor.six_b,
visitor.six_calls,
visitor.six_calls_int2byte,
visitor.six_iter,
visitor.six_raise_from,
visitor.six_reraise,
Expand Down Expand Up @@ -2167,7 +2190,14 @@ def _replace(i: int, mapping: Dict[str, str], node: NameOrAttr) -> None:
call = visitor.six_calls[token.offset]
assert isinstance(call.func, (ast.Name, ast.Attribute))
template = _get_tmpl(SIX_CALLS, call.func)
_replace_call(tokens, i, end, func_args, template)
if isinstance(call.args[0], _EXPR_NEEDS_PARENS):
_replace_call(tokens, i, end, func_args, template, parens=(0,))
else:
_replace_call(tokens, i, end, func_args, template)
elif token.offset in visitor.six_calls_int2byte:
j = _find_open_paren(tokens, i)
func_args, end = _parse_call_args(tokens, j)
_replace_call(tokens, i, end, func_args, SIX_INT2BYTE_TMPL)
elif token.offset in visitor.six_raise_from:
j = _find_open_paren(tokens, i)
func_args, end = _parse_call_args(tokens, j)
Expand Down
67 changes: 67 additions & 0 deletions tests/six_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import pytest

from pyupgrade import _fix_py3_plus
Expand Down Expand Up @@ -42,6 +44,7 @@
id='relative import might not be six',
),
('traceback.format_exc(*sys.exc_info())'),
pytest.param('six.iteritems()', id='wrong argument count'),
),
)
def test_fix_six_noop(s):
Expand Down Expand Up @@ -382,12 +385,76 @@ def test_fix_six_noop(s):
'print(next(iter({1:2}.values())))\n',
id='six.itervalues inside next(...)',
),
pytest.param(
'for _ in six.itervalues({} or y): pass',
'for _ in ({} or y).values(): pass',
id='needs parenthesizing for BoolOp',
),
pytest.param(
'for _ in six.itervalues({} | y): pass',
'for _ in ({} | y).values(): pass',
id='needs parenthesizing for BinOp',
),
pytest.param(
'six.int2byte(x | y)',
'bytes((x | y,))',
id='no parenthesize for int2byte BinOP',
),
pytest.param(
'six.iteritems(+weird_dct)',
'(+weird_dct).items()',
id='needs parenthesizing for UnaryOp',
),
pytest.param(
'x = six.get_method_function(lambda: x)',
'x = (lambda: x).__func__',
id='needs parenthesizing for Lambda',
),
pytest.param(
'for _ in six.itervalues(x if 1 else y): pass',
'for _ in (x if 1 else y).values(): pass',
id='needs parenthesizing for IfExp',
),
# this one is bogus / impossible, but parenthesize it anyway
pytest.param(
'six.itervalues(x for x in y)',
'(x for x in y).values()',
id='needs parentehsizing for GeneratorExp',
),
pytest.param(
'async def f():\n'
' return six.iteritems(await y)\n',
'async def f():\n'
' return (await y).items()\n',
id='needs parenthesizing for Await',
),
# this one is bogus / impossible, but parenthesize it anyway
pytest.param(
'six.itervalues(x < y)',
'(x < y).values()',
id='needs parentehsizing for Compare',
),
),
)
def test_fix_six(s, expected):
assert _fix_py3_plus(s, (3,)) == expected


@pytest.mark.xfail(sys.version_info < (3, 8), reason='walrus')
@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'for _ in six.itervalues(x := y): pass',
'for _ in (x := y).values(): pass',
id='needs parenthesizing for NamedExpr',
),
),
)
def test_fix_six_py38_plus(s, expected):
assert _fix_py3_plus(s, (3,)) == expected


@pytest.mark.parametrize(
('s', 'expected'),
(
Expand Down

0 comments on commit b1773d9

Please sign in to comment.