From 80a63c0026ea0964ccaf1752914d5877d7a2341a Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Mon, 11 Jul 2022 05:21:04 -0700 Subject: [PATCH] fix replacement of entire indented import --- pyupgrade/_plugins/imports.py | 20 +++++++++++++++++--- pyupgrade/_token_helpers.py | 8 ++++++-- tests/features/import_replaces_test.py | 16 ++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/pyupgrade/_plugins/imports.py b/pyupgrade/_plugins/imports.py index 6d24864a..989553aa 100644 --- a/pyupgrade/_plugins/imports.py +++ b/pyupgrade/_plugins/imports.py @@ -17,6 +17,7 @@ from pyupgrade._data import TokenFunc from pyupgrade._token_helpers import find_end from pyupgrade._token_helpers import find_token +from pyupgrade._token_helpers import has_space_before from pyupgrade._token_helpers import indented_amount # GENERATED VIA generate-imports @@ -259,6 +260,7 @@ def _remove_import(i: int, tokens: list[Token]) -> None: class FromImport(NamedTuple): + start: int mod_start: int mod_end: int names: tuple[int, ...] @@ -266,6 +268,11 @@ class FromImport(NamedTuple): @classmethod def parse(cls, i: int, tokens: list[Token]) -> FromImport: + if has_space_before(i, tokens): + start = i - 1 + else: + start = i + j = i + 1 # XXX: does not handle explicit relative imports while tokens[j].name != 'NAME': @@ -290,7 +297,10 @@ def parse(cls, i: int, tokens: list[Token]) -> FromImport: if tokens[names[i]].src == 'as': del names[i:i + 2] - return cls(mod_start, mod_end + 1, tuple(names), end) + return cls(start, mod_start, mod_end + 1, tuple(names), end) + + def remove_self(self, tokens: list[Token]) -> None: + del tokens[self.start:self.end] def replace_modname(self, tokens: list[Token], modname: str) -> None: tokens[self.mod_start:self.mod_end] = [Token('CODE', modname)] @@ -365,7 +375,7 @@ def _replace_from_mixed( # all names rewritten -- delete import if len(parsed.names) == len(removal_idxs): - del tokens[i:parsed.end] + parsed.remove_self(tokens) else: parsed.remove_parts(tokens, removal_idxs) @@ -443,6 +453,10 @@ def _replace_import( except ValueError: return + if has_space_before(i, tokens): + start = i - 1 + else: + start = i end = find_end(tokens, i) parts = [] @@ -478,7 +492,7 @@ def _replace_import( tokens[end:end] = [Token('CODE', ''.join(new_imports))] if len(to_from) == len(parts): - del tokens[i:end] + del tokens[start:end] else: for idx, _, _, _ in reversed(to_from): if idx == 0: # look forward until next name and del diff --git a/pyupgrade/_token_helpers.py b/pyupgrade/_token_helpers.py index 8a8e916f..397b77e1 100644 --- a/pyupgrade/_token_helpers.py +++ b/pyupgrade/_token_helpers.py @@ -517,11 +517,15 @@ def replace_list_comp_brackets(i: int, tokens: list[Token]) -> None: tokens[start] = Token('OP', '(') +def has_space_before(i: int, tokens: list[Token]) -> bool: + return i >= 1 and tokens[i - 1].name in {UNIMPORTANT_WS, 'INDENT'} + + def indented_amount(i: int, tokens: list[Token]) -> str: if i == 0: return '' - elif i >= 2 and tokens[i - 1].name in {UNIMPORTANT_WS, 'INDENT'}: - if tokens[i - 2].name in {'NL', 'NEWLINE', 'DEDENT'}: + elif has_space_before(i, tokens): + if i >= 2 and tokens[i - 2].name in {'NL', 'NEWLINE', 'DEDENT'}: return tokens[i - 1].src else: # inline import raise ValueError('not at beginning of line') diff --git a/tests/features/import_replaces_test.py b/tests/features/import_replaces_test.py index d53acbe0..f31b5529 100644 --- a/tests/features/import_replaces_test.py +++ b/tests/features/import_replaces_test.py @@ -134,6 +134,14 @@ def test_mock_noop_keep_mock(): 'from collections.abc import Mapping as mapping\n', id='new import with aliased name', ), + pytest.param( + 'if True:\n' + ' from xml.etree import cElementTree as ET\n', + (3,), + 'if True:\n' + ' from xml.etree import ElementTree as ET\n', + id='indented and full import replaced', + ), pytest.param( 'if True:\n' ' from collections import Mapping, Counter\n', @@ -152,6 +160,14 @@ def test_mock_noop_keep_mock(): ' import queue\n', id='indented import-import being added', ), + pytest.param( + 'if True:\n' + ' import mock\n', + (3,), + 'if True:\n' + ' from unittest import mock\n', + id='indented import-import rewritten', + ), pytest.param( 'if True:\n' ' if True:\n'