Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix replacement of entire indented import #680

Merged
merged 1 commit into from Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 17 additions & 3 deletions pyupgrade/_plugins/imports.py
Expand Up @@ -17,6 +17,7 @@
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_end
from pyupgrade._token_helpers import find_token
from pyupgrade._token_helpers import has_space_before
from pyupgrade._token_helpers import indented_amount

# GENERATED VIA generate-imports
Expand Down Expand Up @@ -259,13 +260,19 @@ def _remove_import(i: int, tokens: list[Token]) -> None:


class FromImport(NamedTuple):
start: int
mod_start: int
mod_end: int
names: tuple[int, ...]
end: int

@classmethod
def parse(cls, i: int, tokens: list[Token]) -> FromImport:
if has_space_before(i, tokens):
start = i - 1
else:
start = i

j = i + 1
# XXX: does not handle explicit relative imports
while tokens[j].name != 'NAME':
Expand All @@ -290,7 +297,10 @@ def parse(cls, i: int, tokens: list[Token]) -> FromImport:
if tokens[names[i]].src == 'as':
del names[i:i + 2]

return cls(mod_start, mod_end + 1, tuple(names), end)
return cls(start, mod_start, mod_end + 1, tuple(names), end)

def remove_self(self, tokens: list[Token]) -> None:
del tokens[self.start:self.end]

def replace_modname(self, tokens: list[Token], modname: str) -> None:
tokens[self.mod_start:self.mod_end] = [Token('CODE', modname)]
Expand Down Expand Up @@ -365,7 +375,7 @@ def _replace_from_mixed(

# all names rewritten -- delete import
if len(parsed.names) == len(removal_idxs):
del tokens[i:parsed.end]
parsed.remove_self(tokens)
else:
parsed.remove_parts(tokens, removal_idxs)

Expand Down Expand Up @@ -443,6 +453,10 @@ def _replace_import(
except ValueError:
return

if has_space_before(i, tokens):
start = i - 1
else:
start = i
end = find_end(tokens, i)

parts = []
Expand Down Expand Up @@ -478,7 +492,7 @@ def _replace_import(
tokens[end:end] = [Token('CODE', ''.join(new_imports))]

if len(to_from) == len(parts):
del tokens[i:end]
del tokens[start:end]
else:
for idx, _, _, _ in reversed(to_from):
if idx == 0: # look forward until next name and del
Expand Down
8 changes: 6 additions & 2 deletions pyupgrade/_token_helpers.py
Expand Up @@ -517,11 +517,15 @@ def replace_list_comp_brackets(i: int, tokens: list[Token]) -> None:
tokens[start] = Token('OP', '(')


def has_space_before(i: int, tokens: list[Token]) -> bool:
return i >= 1 and tokens[i - 1].name in {UNIMPORTANT_WS, 'INDENT'}


def indented_amount(i: int, tokens: list[Token]) -> str:
if i == 0:
return ''
elif i >= 2 and tokens[i - 1].name in {UNIMPORTANT_WS, 'INDENT'}:
if tokens[i - 2].name in {'NL', 'NEWLINE', 'DEDENT'}:
elif has_space_before(i, tokens):
if i >= 2 and tokens[i - 2].name in {'NL', 'NEWLINE', 'DEDENT'}:
return tokens[i - 1].src
else: # inline import
raise ValueError('not at beginning of line')
Expand Down
16 changes: 16 additions & 0 deletions tests/features/import_replaces_test.py
Expand Up @@ -134,6 +134,14 @@ def test_mock_noop_keep_mock():
'from collections.abc import Mapping as mapping\n',
id='new import with aliased name',
),
pytest.param(
'if True:\n'
' from xml.etree import cElementTree as ET\n',
(3,),
'if True:\n'
' from xml.etree import ElementTree as ET\n',
id='indented and full import replaced',
),
pytest.param(
'if True:\n'
' from collections import Mapping, Counter\n',
Expand All @@ -152,6 +160,14 @@ def test_mock_noop_keep_mock():
' import queue\n',
id='indented import-import being added',
),
pytest.param(
'if True:\n'
' import mock\n',
(3,),
'if True:\n'
' from unittest import mock\n',
id='indented import-import rewritten',
),
pytest.param(
'if True:\n'
' if True:\n'
Expand Down