Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

combine mock and imports plugins #673

Merged
merged 1 commit into from Jul 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 72 additions & 26 deletions pyupgrade/_plugins/imports.py
Expand Up @@ -223,6 +223,8 @@
@functools.lru_cache(maxsize=None)
def _for_version(
version: tuple[int, ...],
*,
keep_mock: bool,
) -> tuple[
Mapping[str, set[str]],
Mapping[tuple[str, str], str],
Expand All @@ -238,10 +240,15 @@ def _for_version(
if ver <= version:
exact.update(ver_exact)

mods = {}
if version >= (3,):
mods = REPLACE_MODS
else:
mods = {}
mods.update(REPLACE_MODS)
if not keep_mock:
exact['mock', 'mock'] = 'unittest'
mods.update({
'mock': 'unittest.mock',
'mock.mock': 'unittest.mock',
})

return removals, exact, mods

Expand Down Expand Up @@ -347,26 +354,26 @@ def _replace_from_mixed(
added_from_imports[mod].append(_alias_to_s(alias))
bisect.insort(removal_idxs, idx)

added_imports = []
new_imports = []
for idx, new_mod, alias in module_moves:
new_mod, _, new_sym = new_mod.rpartition('.')
new_alias = ast.alias(name=new_sym, asname=alias.asname)
if new_mod:
added_from_imports[new_mod].append(_alias_to_s(new_alias))
else:
added_imports.append(f'import {_alias_to_s(new_alias)}\n')
new_imports.append(f'import {_alias_to_s(new_alias)}\n')
bisect.insort(removal_idxs, idx)

added_imports.extend(
new_imports.extend(
f'{indent}from {mod} import {", ".join(names)}\n'
for mod, names in added_from_imports.items()
)
added_imports.sort()
new_imports.sort()

if added_imports and tokens[parsed.end - 1].src != '\n':
added_imports.insert(0, '\n')
if new_imports and tokens[parsed.end - 1].src != '\n':
new_imports.insert(0, '\n')

tokens[parsed.end:parsed.end] = [Token('CODE', ''.join(added_imports))]
tokens[parsed.end:parsed.end] = [Token('CODE', ''.join(new_imports))]

# all names rewritten -- delete import
if len(parsed.names) == len(removal_idxs):
Expand All @@ -381,7 +388,10 @@ def visit_ImportFrom(
node: ast.ImportFrom,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
removals, exact, mods = _for_version(state.settings.min_version)
removals, exact, mods = _for_version(
state.settings.min_version,
keep_mock=state.settings.keep_mock,
)

# we don't have any relative rewrites
if node.level != 0 or node.module is None:
Expand Down Expand Up @@ -413,9 +423,6 @@ def visit_ImportFrom(

if len(removal_idxs) == len(node.names):
yield ast_to_offset(node), _remove_import
elif mod in mods:
func = functools.partial(_replace_from_modname, modname=mods[mod])
yield ast_to_offset(node), func
elif (
len(exact_moves) == len(node.names) and
len({mod for _, mod, _ in exact_moves}) == 1
Expand All @@ -431,13 +438,17 @@ def visit_ImportFrom(
module_moves=module_moves,
)
yield ast_to_offset(node), func
elif mod in mods:
func = functools.partial(_replace_from_modname, modname=mods[mod])
yield ast_to_offset(node), func


def _replace_import(
i: int,
tokens: list[Token],
*,
exact_moves: list[tuple[int, str, str]],
exact_moves: list[tuple[int, str, ast.alias]],
to_from: list[tuple[int, str, str, ast.alias]],
) -> None:
end = find_end(tokens, i)

Expand All @@ -458,9 +469,32 @@ def _replace_import(
assert start_idx is not None and end_idx is not None
parts.append((start_idx, end_idx))

for i, new_mod, asname in reversed(exact_moves):
new_alias = ast.alias(name=new_mod, asname=asname)
tokens[slice(*parts[i])] = [Token('CODE', _alias_to_s(new_alias))]
for idx, new_mod, alias in reversed(exact_moves):
new_alias = ast.alias(name=new_mod, asname=alias.asname)
tokens[slice(*parts[idx])] = [Token('CODE', _alias_to_s(new_alias))]

new_imports = sorted(
f'from {new_mod} import '
f'{_alias_to_s(ast.alias(name=new_sym, asname=alias.asname))}\n'
for _, new_mod, new_sym, alias in to_from
)

if new_imports and tokens[end - 1].src != '\n':
new_imports.insert(0, '\n')

tokens[end:end] = [Token('CODE', ''.join(new_imports))]

if len(to_from) == len(parts):
del tokens[i:end]
else:
for idx, _, _, _ in reversed(to_from):
if idx == 0: # look forward until next name and del
del tokens[parts[idx][0]:parts[idx + 1][0]]
else: # look backward for comma and del
j = part_end = parts[idx][0]
while tokens[j].src != ',':
j -= 1
del tokens[j:part_end + 1]


@register(ast.Import)
Expand All @@ -469,15 +503,27 @@ def visit_Import(
node: ast.Import,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
_, _, mods = _for_version(state.settings.min_version)
_, _, mods = _for_version(
state.settings.min_version,
keep_mock=state.settings.keep_mock,
)

to_from = []
exact_moves = []
for i, alias in enumerate(node.names):
if alias.asname is not None:
new_mod = mods.get(alias.name)
if new_mod is not None:
exact_moves.append((i, new_mod, alias.asname))

if exact_moves:
func = functools.partial(_replace_import, exact_moves=exact_moves)
new_mod = mods.get(alias.name)
if new_mod is not None:
alias_base, _, _ = alias.name.partition('.')
new_mod_base, _, new_sym = new_mod.rpartition('.')
if new_mod_base and new_sym == alias_base:
to_from.append((i, new_mod_base, new_sym, alias))
elif alias.asname is not None:
exact_moves.append((i, new_mod, alias))

if to_from or exact_moves:
func = functools.partial(
_replace_import,
exact_moves=exact_moves,
to_from=to_from,
)
yield ast_to_offset(node), func
58 changes: 0 additions & 58 deletions pyupgrade/_plugins/mock.py
Expand Up @@ -12,70 +12,12 @@
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_token

MOCK_MODULES = frozenset(('mock', 'mock.mock'))


def _fix_import_from_mock(i: int, tokens: list[Token]) -> None:
j = find_token(tokens, i, 'mock')
if (
j + 2 < len(tokens) and
tokens[j + 1].src == '.' and
tokens[j + 2].src == 'mock'
):
k = j + 2
else:
k = j
src = 'unittest.mock'
tokens[j:k + 1] = [tokens[j]._replace(name='NAME', src=src)]


def _fix_import_mock(i: int, tokens: list[Token]) -> None:
j = find_token(tokens, i, 'mock')
if (
j + 2 < len(tokens) and
tokens[j + 1].src == '.' and
tokens[j + 2].src == 'mock'
):
j += 2
src = 'from unittest import mock'
tokens[i:j + 1] = [tokens[j]._replace(name='NAME', src=src)]


def _fix_mock_mock(i: int, tokens: list[Token]) -> None:
j = find_token(tokens, i + 1, 'mock')
del tokens[i + 1:j + 1]


@register(ast.ImportFrom)
def visit_ImportFrom(
state: State,
node: ast.ImportFrom,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
state.settings.min_version >= (3,) and
not state.settings.keep_mock and
not node.level and
node.module in MOCK_MODULES
):
yield ast_to_offset(node), _fix_import_from_mock


@register(ast.Import)
def visit_Import(
state: State,
node: ast.Import,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
state.settings.min_version >= (3,) and
not state.settings.keep_mock and
len(node.names) == 1 and
node.names[0].name in MOCK_MODULES
):
yield ast_to_offset(node), _fix_import_mock


@register(ast.Attribute)
def visit_Attribute(
state: State,
Expand Down
8 changes: 1 addition & 7 deletions testing/generate-imports
Expand Up @@ -67,6 +67,7 @@ def _replacements() -> tuple[
replaces = reorder_python_imports.Replacements.make([
reorder_python_imports._validate_replace_import(s)
for s in vals
if 'mock' not in s
])
if replaces.exact:
exact[ver].update(replaces.exact)
Expand All @@ -81,13 +82,6 @@ def main() -> int:

exact, mods = _replacements()

# for now, let the mock plugin continue to handle mock stuff
exact[(3,)] = {
(mod, attr): new
for (mod, attr), new in exact[(3,)].items()
if mod not in {'mock', 'mock.mock'}
}

print(f'# GENERATED VIA {os.path.basename(sys.argv[0])}')
print(f'# Using reorder-python-imports=={version}')
print('REMOVALS = {')
Expand Down
50 changes: 50 additions & 0 deletions tests/features/import_replaces_test.py
Expand Up @@ -46,6 +46,17 @@ def test_import_replaces_noop(s, min_version):
assert _fix_plugins(s, settings=Settings(min_version=min_version)) == s


def test_mock_noop_keep_mock():
"""This would've been rewritten if keep_mock were False"""
s = (
'from mock import patch\n'
'\n'
'patch("func")'
)
settings = Settings(min_version=(3,), keep_mock=True)
assert _fix_plugins(s, settings=settings) == s


@pytest.mark.parametrize(
('s', 'min_version', 'expected'),
(
Expand Down Expand Up @@ -219,6 +230,45 @@ def test_import_replaces_noop(s, min_version):
'import contextlib, xml.etree.ElementTree as ET\n',
id='can rewrite multiple import imports',
),
pytest.param(
'import mock\n',
(3,),
'from unittest import mock\n',
id='rewrites mock import',
),
pytest.param(
'import mock.mock\n',
(3,),
'from unittest import mock\n',
id='rewrites mock.mock import',
),
pytest.param(
'import contextlib, mock, sys\n',
(3,),
'import contextlib, sys\n'
'from unittest import mock\n',
id='mock rewriting multiple imports in middle',
),
pytest.param(
'import mock, sys\n',
(3,),
'import sys\n'
'from unittest import mock\n',
id='mock rewriting multiple imports at beginning',
),
pytest.param(
'import mock, sys',
(3,),
'import sys\n'
'from unittest import mock\n',
id='adds import-import no eol',
),
pytest.param(
'from mock import mock\n',
(3,),
'from unittest import mock\n',
id='mock import mock import',
),
),
)
def test_import_replaces(s, min_version, expected):
Expand Down