Skip to content

Commit

Permalink
Ensure import lines with ';' are properly split and auto corrected
Browse files Browse the repository at this point in the history
  • Loading branch information
superDross committed Jan 23, 2022
1 parent 4941c06 commit 07e8af0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/autoimport/model.py
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import autoflake
from pyflakes.messages import UndefinedExport, UndefinedName, UnusedImport
Expand Down Expand Up @@ -213,7 +213,7 @@ def _move_imports_to_top(self) -> None:
multiline_string = False
code_lines_to_remove = []

for line in self.code:
for line_num, line in enumerate(self.code):
# Process multiline strings, taking care not to catch single line strings
# defined with three quotes.
if re.match(r"^.*?(\"|\'){3}.*?(?!\1{3})$", line) and not re.match(
Expand All @@ -232,6 +232,13 @@ def _move_imports_to_top(self) -> None:
if re.match(r".*?# ?noqa:.*?autoimport.*", line):
continue

# process lines using separation markers
if ";" in line:
import_line, next_line = self._split_separation_line(line)
self.imports.append(import_line.strip())
self.code[line_num] = next_line
continue

# Process multiline import statements
if "(" in line:
multiline_import = True
Expand All @@ -247,6 +254,14 @@ def _move_imports_to_top(self) -> None:
for line in code_lines_to_remove:
self.code.remove(line)

def _split_separation_line(self, line: str) -> Tuple[str, str]:
"""Split separation lines into two and return both lines back."""
first_line, next_line = line.split(";")
# add correct number of leading spaces
num_lspaces = len(first_line) - len(first_line.lstrip())
next_line = f"{' ' * num_lspaces}{next_line.lstrip()}"
return first_line, next_line

def _fix_flake_import_errors(self) -> None:
"""Fix python source code to correct missed or unused import statements."""
error_messages = autoflake.check(self._join_code())
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_services.py
Expand Up @@ -1041,3 +1041,53 @@ def test_file_with_non_used_multiline_import() -> None:
result = fix_code(source)

assert result == "\n"


def test_file_with_import_and_seperator() -> None:
"""Ensure import lines with seperators are fixed correctly."""
source = dedent(
"""
a = 1
import pdb;pdb.set_trace()
b = 2
"""
)
expected = dedent(
"""
import pdb
a = 1
pdb.set_trace()
b = 2
"""
).replace("\n", "", 1)

result = fix_code(source)

assert result == expected


def test_file_with_import_and_seperator_indentation() -> None:
"""Ensure import lines with seperators are fixed correctly when indented."""
source = dedent(
"""
Class Person:
import pdb; pdb.set_trace()
def say_hi(self):
print('hi')
"""
)
expected = dedent(
"""
import pdb
Class Person:
pdb.set_trace()
def say_hi(self):
print('hi')
"""
).replace("\n", "", 1)

result = fix_code(source)

assert result == expected

0 comments on commit 07e8af0

Please sign in to comment.