Skip to content

Commit

Permalink
Merge pull request #673 from asottile/mock-imports
Browse files Browse the repository at this point in the history
combine mock and imports plugins
  • Loading branch information
asottile committed Jul 10, 2022
2 parents e323e34 + 1a904a7 commit 1f86c30
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 206 deletions.
98 changes: 72 additions & 26 deletions pyupgrade/_plugins/imports.py
Expand Up @@ -223,6 +223,8 @@
@functools.lru_cache(maxsize=None)
def _for_version(
version: tuple[int, ...],
*,
keep_mock: bool,
) -> tuple[
Mapping[str, set[str]],
Mapping[tuple[str, str], str],
Expand All @@ -238,10 +240,15 @@ def _for_version(
if ver <= version:
exact.update(ver_exact)

mods = {}
if version >= (3,):
mods = REPLACE_MODS
else:
mods = {}
mods.update(REPLACE_MODS)
if not keep_mock:
exact['mock', 'mock'] = 'unittest'
mods.update({
'mock': 'unittest.mock',
'mock.mock': 'unittest.mock',
})

return removals, exact, mods

Expand Down Expand Up @@ -347,26 +354,26 @@ def _replace_from_mixed(
added_from_imports[mod].append(_alias_to_s(alias))
bisect.insort(removal_idxs, idx)

added_imports = []
new_imports = []
for idx, new_mod, alias in module_moves:
new_mod, _, new_sym = new_mod.rpartition('.')
new_alias = ast.alias(name=new_sym, asname=alias.asname)
if new_mod:
added_from_imports[new_mod].append(_alias_to_s(new_alias))
else:
added_imports.append(f'import {_alias_to_s(new_alias)}\n')
new_imports.append(f'import {_alias_to_s(new_alias)}\n')
bisect.insort(removal_idxs, idx)

added_imports.extend(
new_imports.extend(
f'{indent}from {mod} import {", ".join(names)}\n'
for mod, names in added_from_imports.items()
)
added_imports.sort()
new_imports.sort()

if added_imports and tokens[parsed.end - 1].src != '\n':
added_imports.insert(0, '\n')
if new_imports and tokens[parsed.end - 1].src != '\n':
new_imports.insert(0, '\n')

tokens[parsed.end:parsed.end] = [Token('CODE', ''.join(added_imports))]
tokens[parsed.end:parsed.end] = [Token('CODE', ''.join(new_imports))]

# all names rewritten -- delete import
if len(parsed.names) == len(removal_idxs):
Expand All @@ -381,7 +388,10 @@ def visit_ImportFrom(
node: ast.ImportFrom,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
removals, exact, mods = _for_version(state.settings.min_version)
removals, exact, mods = _for_version(
state.settings.min_version,
keep_mock=state.settings.keep_mock,
)

# we don't have any relative rewrites
if node.level != 0 or node.module is None:
Expand Down Expand Up @@ -413,9 +423,6 @@ def visit_ImportFrom(

if len(removal_idxs) == len(node.names):
yield ast_to_offset(node), _remove_import
elif mod in mods:
func = functools.partial(_replace_from_modname, modname=mods[mod])
yield ast_to_offset(node), func
elif (
len(exact_moves) == len(node.names) and
len({mod for _, mod, _ in exact_moves}) == 1
Expand All @@ -431,13 +438,17 @@ def visit_ImportFrom(
module_moves=module_moves,
)
yield ast_to_offset(node), func
elif mod in mods:
func = functools.partial(_replace_from_modname, modname=mods[mod])
yield ast_to_offset(node), func


def _replace_import(
i: int,
tokens: list[Token],
*,
exact_moves: list[tuple[int, str, str]],
exact_moves: list[tuple[int, str, ast.alias]],
to_from: list[tuple[int, str, str, ast.alias]],
) -> None:
end = find_end(tokens, i)

Expand All @@ -458,9 +469,32 @@ def _replace_import(
assert start_idx is not None and end_idx is not None
parts.append((start_idx, end_idx))

for i, new_mod, asname in reversed(exact_moves):
new_alias = ast.alias(name=new_mod, asname=asname)
tokens[slice(*parts[i])] = [Token('CODE', _alias_to_s(new_alias))]
for idx, new_mod, alias in reversed(exact_moves):
new_alias = ast.alias(name=new_mod, asname=alias.asname)
tokens[slice(*parts[idx])] = [Token('CODE', _alias_to_s(new_alias))]

new_imports = sorted(
f'from {new_mod} import '
f'{_alias_to_s(ast.alias(name=new_sym, asname=alias.asname))}\n'
for _, new_mod, new_sym, alias in to_from
)

if new_imports and tokens[end - 1].src != '\n':
new_imports.insert(0, '\n')

tokens[end:end] = [Token('CODE', ''.join(new_imports))]

if len(to_from) == len(parts):
del tokens[i:end]
else:
for idx, _, _, _ in reversed(to_from):
if idx == 0: # look forward until next name and del
del tokens[parts[idx][0]:parts[idx + 1][0]]
else: # look backward for comma and del
j = part_end = parts[idx][0]
while tokens[j].src != ',':
j -= 1
del tokens[j:part_end + 1]


@register(ast.Import)
Expand All @@ -469,15 +503,27 @@ def visit_Import(
node: ast.Import,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
_, _, mods = _for_version(state.settings.min_version)
_, _, mods = _for_version(
state.settings.min_version,
keep_mock=state.settings.keep_mock,
)

to_from = []
exact_moves = []
for i, alias in enumerate(node.names):
if alias.asname is not None:
new_mod = mods.get(alias.name)
if new_mod is not None:
exact_moves.append((i, new_mod, alias.asname))

if exact_moves:
func = functools.partial(_replace_import, exact_moves=exact_moves)
new_mod = mods.get(alias.name)
if new_mod is not None:
alias_base, _, _ = alias.name.partition('.')
new_mod_base, _, new_sym = new_mod.rpartition('.')
if new_mod_base and new_sym == alias_base:
to_from.append((i, new_mod_base, new_sym, alias))
elif alias.asname is not None:
exact_moves.append((i, new_mod, alias))

if to_from or exact_moves:
func = functools.partial(
_replace_import,
exact_moves=exact_moves,
to_from=to_from,
)
yield ast_to_offset(node), func
58 changes: 0 additions & 58 deletions pyupgrade/_plugins/mock.py
Expand Up @@ -12,70 +12,12 @@
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_token

MOCK_MODULES = frozenset(('mock', 'mock.mock'))


def _fix_import_from_mock(i: int, tokens: list[Token]) -> None:
j = find_token(tokens, i, 'mock')
if (
j + 2 < len(tokens) and
tokens[j + 1].src == '.' and
tokens[j + 2].src == 'mock'
):
k = j + 2
else:
k = j
src = 'unittest.mock'
tokens[j:k + 1] = [tokens[j]._replace(name='NAME', src=src)]


def _fix_import_mock(i: int, tokens: list[Token]) -> None:
j = find_token(tokens, i, 'mock')
if (
j + 2 < len(tokens) and
tokens[j + 1].src == '.' and
tokens[j + 2].src == 'mock'
):
j += 2
src = 'from unittest import mock'
tokens[i:j + 1] = [tokens[j]._replace(name='NAME', src=src)]


def _fix_mock_mock(i: int, tokens: list[Token]) -> None:
j = find_token(tokens, i + 1, 'mock')
del tokens[i + 1:j + 1]


@register(ast.ImportFrom)
def visit_ImportFrom(
state: State,
node: ast.ImportFrom,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
state.settings.min_version >= (3,) and
not state.settings.keep_mock and
not node.level and
node.module in MOCK_MODULES
):
yield ast_to_offset(node), _fix_import_from_mock


@register(ast.Import)
def visit_Import(
state: State,
node: ast.Import,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
state.settings.min_version >= (3,) and
not state.settings.keep_mock and
len(node.names) == 1 and
node.names[0].name in MOCK_MODULES
):
yield ast_to_offset(node), _fix_import_mock


@register(ast.Attribute)
def visit_Attribute(
state: State,
Expand Down
8 changes: 1 addition & 7 deletions testing/generate-imports
Expand Up @@ -67,6 +67,7 @@ def _replacements() -> tuple[
replaces = reorder_python_imports.Replacements.make([
reorder_python_imports._validate_replace_import(s)
for s in vals
if 'mock' not in s
])
if replaces.exact:
exact[ver].update(replaces.exact)
Expand All @@ -81,13 +82,6 @@ def main() -> int:

exact, mods = _replacements()

# for now, let the mock plugin continue to handle mock stuff
exact[(3,)] = {
(mod, attr): new
for (mod, attr), new in exact[(3,)].items()
if mod not in {'mock', 'mock.mock'}
}

print(f'# GENERATED VIA {os.path.basename(sys.argv[0])}')
print(f'# Using reorder-python-imports=={version}')
print('REMOVALS = {')
Expand Down
50 changes: 50 additions & 0 deletions tests/features/import_replaces_test.py
Expand Up @@ -46,6 +46,17 @@ def test_import_replaces_noop(s, min_version):
assert _fix_plugins(s, settings=Settings(min_version=min_version)) == s


def test_mock_noop_keep_mock():
"""This would've been rewritten if keep_mock were False"""
s = (
'from mock import patch\n'
'\n'
'patch("func")'
)
settings = Settings(min_version=(3,), keep_mock=True)
assert _fix_plugins(s, settings=settings) == s


@pytest.mark.parametrize(
('s', 'min_version', 'expected'),
(
Expand Down Expand Up @@ -219,6 +230,45 @@ def test_import_replaces_noop(s, min_version):
'import contextlib, xml.etree.ElementTree as ET\n',
id='can rewrite multiple import imports',
),
pytest.param(
'import mock\n',
(3,),
'from unittest import mock\n',
id='rewrites mock import',
),
pytest.param(
'import mock.mock\n',
(3,),
'from unittest import mock\n',
id='rewrites mock.mock import',
),
pytest.param(
'import contextlib, mock, sys\n',
(3,),
'import contextlib, sys\n'
'from unittest import mock\n',
id='mock rewriting multiple imports in middle',
),
pytest.param(
'import mock, sys\n',
(3,),
'import sys\n'
'from unittest import mock\n',
id='mock rewriting multiple imports at beginning',
),
pytest.param(
'import mock, sys',
(3,),
'import sys\n'
'from unittest import mock\n',
id='adds import-import no eol',
),
pytest.param(
'from mock import mock\n',
(3,),
'from unittest import mock\n',
id='mock import mock import',
),
),
)
def test_import_replaces(s, min_version, expected):
Expand Down

0 comments on commit 1f86c30

Please sign in to comment.