diff --git a/pyupgrade/_plugins/imports.py b/pyupgrade/_plugins/imports.py index 6078b355..2883f958 100644 --- a/pyupgrade/_plugins/imports.py +++ b/pyupgrade/_plugins/imports.py @@ -223,6 +223,8 @@ @functools.lru_cache(maxsize=None) def _for_version( version: tuple[int, ...], + *, + keep_mock: bool, ) -> tuple[ Mapping[str, set[str]], Mapping[tuple[str, str], str], @@ -238,10 +240,15 @@ def _for_version( if ver <= version: exact.update(ver_exact) + mods = {} if version >= (3,): - mods = REPLACE_MODS - else: - mods = {} + mods.update(REPLACE_MODS) + if not keep_mock: + exact['mock', 'mock'] = 'unittest' + mods.update({ + 'mock': 'unittest.mock', + 'mock.mock': 'unittest.mock', + }) return removals, exact, mods @@ -347,26 +354,26 @@ def _replace_from_mixed( added_from_imports[mod].append(_alias_to_s(alias)) bisect.insort(removal_idxs, idx) - added_imports = [] + new_imports = [] for idx, new_mod, alias in module_moves: new_mod, _, new_sym = new_mod.rpartition('.') new_alias = ast.alias(name=new_sym, asname=alias.asname) if new_mod: added_from_imports[new_mod].append(_alias_to_s(new_alias)) else: - added_imports.append(f'import {_alias_to_s(new_alias)}\n') + new_imports.append(f'import {_alias_to_s(new_alias)}\n') bisect.insort(removal_idxs, idx) - added_imports.extend( + new_imports.extend( f'{indent}from {mod} import {", ".join(names)}\n' for mod, names in added_from_imports.items() ) - added_imports.sort() + new_imports.sort() - if added_imports and tokens[parsed.end - 1].src != '\n': - added_imports.insert(0, '\n') + if new_imports and tokens[parsed.end - 1].src != '\n': + new_imports.insert(0, '\n') - tokens[parsed.end:parsed.end] = [Token('CODE', ''.join(added_imports))] + tokens[parsed.end:parsed.end] = [Token('CODE', ''.join(new_imports))] # all names rewritten -- delete import if len(parsed.names) == len(removal_idxs): @@ -381,7 +388,10 @@ def visit_ImportFrom( node: ast.ImportFrom, parent: ast.AST, ) -> Iterable[tuple[Offset, TokenFunc]]: - removals, exact, mods = _for_version(state.settings.min_version) + removals, exact, mods = _for_version( + state.settings.min_version, + keep_mock=state.settings.keep_mock, + ) # we don't have any relative rewrites if node.level != 0 or node.module is None: @@ -413,9 +423,6 @@ def visit_ImportFrom( if len(removal_idxs) == len(node.names): yield ast_to_offset(node), _remove_import - elif mod in mods: - func = functools.partial(_replace_from_modname, modname=mods[mod]) - yield ast_to_offset(node), func elif ( len(exact_moves) == len(node.names) and len({mod for _, mod, _ in exact_moves}) == 1 @@ -431,13 +438,17 @@ def visit_ImportFrom( module_moves=module_moves, ) yield ast_to_offset(node), func + elif mod in mods: + func = functools.partial(_replace_from_modname, modname=mods[mod]) + yield ast_to_offset(node), func def _replace_import( i: int, tokens: list[Token], *, - exact_moves: list[tuple[int, str, str]], + exact_moves: list[tuple[int, str, ast.alias]], + to_from: list[tuple[int, str, str, ast.alias]], ) -> None: end = find_end(tokens, i) @@ -458,9 +469,32 @@ def _replace_import( assert start_idx is not None and end_idx is not None parts.append((start_idx, end_idx)) - for i, new_mod, asname in reversed(exact_moves): - new_alias = ast.alias(name=new_mod, asname=asname) - tokens[slice(*parts[i])] = [Token('CODE', _alias_to_s(new_alias))] + for idx, new_mod, alias in reversed(exact_moves): + new_alias = ast.alias(name=new_mod, asname=alias.asname) + tokens[slice(*parts[idx])] = [Token('CODE', _alias_to_s(new_alias))] + + new_imports = sorted( + f'from {new_mod} import ' + f'{_alias_to_s(ast.alias(name=new_sym, asname=alias.asname))}\n' + for _, new_mod, new_sym, alias in to_from + ) + + if new_imports and tokens[end - 1].src != '\n': + new_imports.insert(0, '\n') + + tokens[end:end] = [Token('CODE', ''.join(new_imports))] + + if len(to_from) == len(parts): + del tokens[i:end] + else: + for idx, _, _, _ in reversed(to_from): + if idx == 0: # look forward until next name and del + del tokens[parts[idx][0]:parts[idx + 1][0]] + else: # look backward for comma and del + j = part_end = parts[idx][0] + while tokens[j].src != ',': + j -= 1 + del tokens[j:part_end + 1] @register(ast.Import) @@ -469,15 +503,27 @@ def visit_Import( node: ast.Import, parent: ast.AST, ) -> Iterable[tuple[Offset, TokenFunc]]: - _, _, mods = _for_version(state.settings.min_version) + _, _, mods = _for_version( + state.settings.min_version, + keep_mock=state.settings.keep_mock, + ) + to_from = [] exact_moves = [] for i, alias in enumerate(node.names): - if alias.asname is not None: - new_mod = mods.get(alias.name) - if new_mod is not None: - exact_moves.append((i, new_mod, alias.asname)) - - if exact_moves: - func = functools.partial(_replace_import, exact_moves=exact_moves) + new_mod = mods.get(alias.name) + if new_mod is not None: + alias_base, _, _ = alias.name.partition('.') + new_mod_base, _, new_sym = new_mod.rpartition('.') + if new_mod_base and new_sym == alias_base: + to_from.append((i, new_mod_base, new_sym, alias)) + elif alias.asname is not None: + exact_moves.append((i, new_mod, alias)) + + if to_from or exact_moves: + func = functools.partial( + _replace_import, + exact_moves=exact_moves, + to_from=to_from, + ) yield ast_to_offset(node), func diff --git a/pyupgrade/_plugins/mock.py b/pyupgrade/_plugins/mock.py index 764acf62..6e567dfa 100644 --- a/pyupgrade/_plugins/mock.py +++ b/pyupgrade/_plugins/mock.py @@ -12,70 +12,12 @@ from pyupgrade._data import TokenFunc from pyupgrade._token_helpers import find_token -MOCK_MODULES = frozenset(('mock', 'mock.mock')) - - -def _fix_import_from_mock(i: int, tokens: list[Token]) -> None: - j = find_token(tokens, i, 'mock') - if ( - j + 2 < len(tokens) and - tokens[j + 1].src == '.' and - tokens[j + 2].src == 'mock' - ): - k = j + 2 - else: - k = j - src = 'unittest.mock' - tokens[j:k + 1] = [tokens[j]._replace(name='NAME', src=src)] - - -def _fix_import_mock(i: int, tokens: list[Token]) -> None: - j = find_token(tokens, i, 'mock') - if ( - j + 2 < len(tokens) and - tokens[j + 1].src == '.' and - tokens[j + 2].src == 'mock' - ): - j += 2 - src = 'from unittest import mock' - tokens[i:j + 1] = [tokens[j]._replace(name='NAME', src=src)] - def _fix_mock_mock(i: int, tokens: list[Token]) -> None: j = find_token(tokens, i + 1, 'mock') del tokens[i + 1:j + 1] -@register(ast.ImportFrom) -def visit_ImportFrom( - state: State, - node: ast.ImportFrom, - parent: ast.AST, -) -> Iterable[tuple[Offset, TokenFunc]]: - if ( - state.settings.min_version >= (3,) and - not state.settings.keep_mock and - not node.level and - node.module in MOCK_MODULES - ): - yield ast_to_offset(node), _fix_import_from_mock - - -@register(ast.Import) -def visit_Import( - state: State, - node: ast.Import, - parent: ast.AST, -) -> Iterable[tuple[Offset, TokenFunc]]: - if ( - state.settings.min_version >= (3,) and - not state.settings.keep_mock and - len(node.names) == 1 and - node.names[0].name in MOCK_MODULES - ): - yield ast_to_offset(node), _fix_import_mock - - @register(ast.Attribute) def visit_Attribute( state: State, diff --git a/testing/generate-imports b/testing/generate-imports index 797df337..6dcb94fc 100755 --- a/testing/generate-imports +++ b/testing/generate-imports @@ -67,6 +67,7 @@ def _replacements() -> tuple[ replaces = reorder_python_imports.Replacements.make([ reorder_python_imports._validate_replace_import(s) for s in vals + if 'mock' not in s ]) if replaces.exact: exact[ver].update(replaces.exact) @@ -81,13 +82,6 @@ def main() -> int: exact, mods = _replacements() - # for now, let the mock plugin continue to handle mock stuff - exact[(3,)] = { - (mod, attr): new - for (mod, attr), new in exact[(3,)].items() - if mod not in {'mock', 'mock.mock'} - } - print(f'# GENERATED VIA {os.path.basename(sys.argv[0])}') print(f'# Using reorder-python-imports=={version}') print('REMOVALS = {') diff --git a/tests/features/import_replaces_test.py b/tests/features/import_replaces_test.py index 36599b94..5f3392c8 100644 --- a/tests/features/import_replaces_test.py +++ b/tests/features/import_replaces_test.py @@ -46,6 +46,17 @@ def test_import_replaces_noop(s, min_version): assert _fix_plugins(s, settings=Settings(min_version=min_version)) == s +def test_mock_noop_keep_mock(): + """This would've been rewritten if keep_mock were False""" + s = ( + 'from mock import patch\n' + '\n' + 'patch("func")' + ) + settings = Settings(min_version=(3,), keep_mock=True) + assert _fix_plugins(s, settings=settings) == s + + @pytest.mark.parametrize( ('s', 'min_version', 'expected'), ( @@ -219,6 +230,45 @@ def test_import_replaces_noop(s, min_version): 'import contextlib, xml.etree.ElementTree as ET\n', id='can rewrite multiple import imports', ), + pytest.param( + 'import mock\n', + (3,), + 'from unittest import mock\n', + id='rewrites mock import', + ), + pytest.param( + 'import mock.mock\n', + (3,), + 'from unittest import mock\n', + id='rewrites mock.mock import', + ), + pytest.param( + 'import contextlib, mock, sys\n', + (3,), + 'import contextlib, sys\n' + 'from unittest import mock\n', + id='mock rewriting multiple imports in middle', + ), + pytest.param( + 'import mock, sys\n', + (3,), + 'import sys\n' + 'from unittest import mock\n', + id='mock rewriting multiple imports at beginning', + ), + pytest.param( + 'import mock, sys', + (3,), + 'import sys\n' + 'from unittest import mock\n', + id='adds import-import no eol', + ), + pytest.param( + 'from mock import mock\n', + (3,), + 'from unittest import mock\n', + id='mock import mock import', + ), ), ) def test_import_replaces(s, min_version, expected): diff --git a/tests/features/mock_test.py b/tests/features/mock_test.py index 30b3121d..66c9dcd7 100644 --- a/tests/features/mock_test.py +++ b/tests/features/mock_test.py @@ -6,64 +6,9 @@ from pyupgrade._main import _fix_plugins -@pytest.mark.parametrize( - 's', - ( - pytest.param( - 'import contextlib, mock, sys\n', - id='does not rewrite multiple imports', - ), - pytest.param( - 'from .mock import patch\n', - id='leave relative imports alone', - ), - ), -) -def test_mock_noop(s): - assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s - - -def test_mock_noop_keep_mock(): - """This would've been rewritten if keep_mock were False""" - s = ( - 'from mock import patch\n' - '\n' - 'patch("func")' - ) - settings = Settings(min_version=(3,), keep_mock=True) - assert _fix_plugins(s, settings=settings) == s - - @pytest.mark.parametrize( ('s', 'expected'), ( - pytest.param( - 'from mock import patch\n' - '\n' - 'patch("func")', - 'from unittest.mock import patch\n' - '\n' - 'patch("func")', - id='relative import func', - ), - pytest.param( - 'import mock\n' - '\n' - 'mock.patch("func")\n', - 'from unittest import mock\n' - '\n' - 'mock.patch("func")\n', - id='absolute import func', - ), - pytest.param( - 'from mock.mock import patch\n' - '\n' - 'patch("func")\n', - 'from unittest.mock import patch\n' - '\n' - 'patch("func")\n', - id='double mock relative import func', - ), pytest.param( 'import mock.mock\n' '\n' @@ -75,34 +20,6 @@ def test_mock_noop_keep_mock(): 'mock.patch("func2")\n', id='double mock absolute import func', ), - - pytest.param( - 'from mock import patch\n' - '\n' - 'patch.object(Foo, "func")\n', - 'from unittest.mock import patch\n' - '\n' - 'patch.object(Foo, "func")\n', - id='relative import func attr', - ), - pytest.param( - 'import mock\n' - '\n' - 'mock.patch.object(Foo, "func")\n', - 'from unittest import mock\n' - '\n' - 'mock.patch.object(Foo, "func")\n', - id='absolute import func attr', - ), - pytest.param( - 'from mock.mock import patch\n' - '\n' - 'patch.object(Foo, "func")\n', - 'from unittest.mock import patch\n' - '\n' - 'patch.object(Foo, "func")\n', - id='double mock relative import func attr', - ), pytest.param( 'import mock.mock\n' '\n' @@ -114,38 +31,6 @@ def test_mock_noop_keep_mock(): 'mock.patch.object(Foo, "func2")\n', id='double mock absolute import func attr', ), - - pytest.param( - 'from mock import patch as patch2\n', - 'from unittest.mock import patch as patch2\n', - id='relative import with as', - ), - pytest.param( - 'import mock as mock2\n', - 'from unittest import mock as mock2\n', - id='absolute import with as', - ), - pytest.param( - 'from mock.mock import patch as patch2\n', - 'from unittest.mock import patch as patch2\n', - id='double mock relative import with as', - ), - pytest.param( - 'import mock.mock as mock2\n', - 'from unittest import mock as mock2\n', - id='double mock absolute import with as', - ), - - pytest.param( - 'from mock import *\n', - 'from unittest.mock import *\n', - id='relative import with star', - ), - pytest.param( - 'from mock.mock import *\n', - 'from unittest.mock import *\n', - id='double mock relative import with star', - ), ), ) def test_fix_mock(s, expected):