From ba4d3e4193ed3faa89ccea89b661f047cae92c58 Mon Sep 17 00:00:00 2001 From: David Ross Date: Tue, 18 Jan 2022 20:07:26 +0000 Subject: [PATCH] Ensure import lines with ';' are properly split and auto corrected --- src/autoimport/model.py | 19 ++++++++++++-- tests/unit/test_services.py | 50 +++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/autoimport/model.py b/src/autoimport/model.py index 56b0485..56afd59 100644 --- a/src/autoimport/model.py +++ b/src/autoimport/model.py @@ -4,7 +4,7 @@ import inspect import os import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import autoflake from pyflakes.messages import UndefinedExport, UndefinedName, UnusedImport @@ -213,7 +213,7 @@ def _move_imports_to_top(self) -> None: multiline_string = False code_lines_to_remove = [] - for line in self.code: + for line_num, line in enumerate(self.code): # Process multiline strings, taking care not to catch single line strings # defined with three quotes. if re.match(r"^.*?(\"|\'){3}.*?(?!\1{3})$", line) and not re.match( @@ -232,6 +232,13 @@ def _move_imports_to_top(self) -> None: if re.match(r".*?# ?noqa:.*?autoimport.*", line): continue + # process lines using separation markers + if ";" in line: + import_line, next_line = self._split_separation_line(line) + self.imports.append(import_line.strip()) + self.code[line_num] = next_line + continue + # Process multiline import statements if "(" in line: multiline_import = True @@ -247,6 +254,14 @@ def _move_imports_to_top(self) -> None: for line in code_lines_to_remove: self.code.remove(line) + def _split_separation_line(self, line: str) -> Tuple[str, str]: + """Split separation lines into two and return both lines back.""" + first_line, next_line = line.split(";") + # add correct number of leading spaces + num_lspaces = len(first_line) - len(first_line.lstrip()) + next_line = f"{' ' * num_lspaces}{next_line.lstrip()}" + return first_line, next_line + def _fix_flake_import_errors(self) -> None: """Fix python source code to correct missed or unused import statements.""" error_messages = autoflake.check(self._join_code()) diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py index 93de538..109921b 100644 --- a/tests/unit/test_services.py +++ b/tests/unit/test_services.py @@ -1021,3 +1021,53 @@ def test_file_with_import_as() -> None: result = fix_code(source) assert result == "\n" + + +def test_file_with_import_and_seperator() -> None: + """Ensure import lines with seperators are fixed correctly.""" + source = dedent( + """ + a = 1 + import pdb;pdb.set_trace() + b = 2 + """ + ) + expected = dedent( + """ + import pdb + + a = 1 + pdb.set_trace() + b = 2 + """ + ).replace("\n", "", 1) + + result = fix_code(source) + + assert result == expected + + +def test_file_with_import_and_seperator_indentation() -> None: + """Ensure import lines with seperators are fixed correctly when indented.""" + source = dedent( + """ + Class Person: + import pdb; pdb.set_trace() + def say_hi(self): + print('hi') + """ + ) + expected = dedent( + """ + import pdb + + Class Person: + pdb.set_trace() + def say_hi(self): + print('hi') + """ + ).replace("\n", "", 1) + + result = fix_code(source) + + assert result == expected