Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure import lines with ';' are properly split and auto corrected #177

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/autoimport/model.py
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import autoflake
from pyflakes.messages import UndefinedExport, UndefinedName, UnusedImport
Expand Down Expand Up @@ -213,7 +213,7 @@ def _move_imports_to_top(self) -> None:
multiline_string = False
code_lines_to_remove = []

for line in self.code:
for line_num, line in enumerate(self.code):
# Process multiline strings, taking care not to catch single line strings
# defined with three quotes.
if re.match(r"^.*?(\"|\'){3}.*?(?!\1{3})$", line) and not re.match(
Expand All @@ -232,6 +232,13 @@ def _move_imports_to_top(self) -> None:
if re.match(r".*?# ?noqa:.*?autoimport.*", line):
continue

# process lines using separation markers
if ";" in line:
import_line, next_line = self._split_separation_line(line)
self.imports.append(import_line.strip())
self.code[line_num] = next_line
continue

# Process multiline import statements
if "(" in line:
multiline_import = True
Expand All @@ -247,6 +254,14 @@ def _move_imports_to_top(self) -> None:
for line in code_lines_to_remove:
self.code.remove(line)

def _split_separation_line(self, line: str) -> Tuple[str, str]:
"""Split separation lines into two and return both lines back."""
first_line, next_line = line.split(";")
# add correct number of leading spaces
num_lspaces = len(first_line) - len(first_line.lstrip())
next_line = f"{' ' * num_lspaces}{next_line.lstrip()}"
return first_line, next_line

def _fix_flake_import_errors(self) -> None:
"""Fix python source code to correct missed or unused import statements."""
error_messages = autoflake.check(self._join_code())
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_services.py
Expand Up @@ -1041,3 +1041,53 @@ def test_file_with_non_used_multiline_import() -> None:
result = fix_code(source)

assert result == "\n"


def test_file_with_import_and_seperator() -> None:
"""Ensure import lines with seperators are fixed correctly."""
source = dedent(
"""
a = 1
import pdb;pdb.set_trace()
b = 2
"""
)
expected = dedent(
"""
import pdb

a = 1
pdb.set_trace()
b = 2
"""
).replace("\n", "", 1)

result = fix_code(source)

assert result == expected


def test_file_with_import_and_seperator_indentation() -> None:
"""Ensure import lines with seperators are fixed correctly when indented."""
source = dedent(
"""
Class Person:
import pdb; pdb.set_trace()
def say_hi(self):
print('hi')
"""
)
expected = dedent(
"""
import pdb

Class Person:
pdb.set_trace()
def say_hi(self):
print('hi')
"""
).replace("\n", "", 1)

result = fix_code(source)

assert result == expected