From 4aa944edf2765f782fa84a31845769adcabd5497 Mon Sep 17 00:00:00 2001 From: David Ross Date: Tue, 18 Jan 2022 20:07:26 +0000 Subject: [PATCH] Ensure import lines with ';' are properly split and auto corrected --- src/autoimport/model.py | 16 +++++++++++- tests/unit/test_services.py | 50 +++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/autoimport/model.py b/src/autoimport/model.py index 56b0485..5db93a4 100644 --- a/src/autoimport/model.py +++ b/src/autoimport/model.py @@ -69,6 +69,20 @@ def fix(self) -> str: return self._join_code() + def _split_into_list(self, source_code: str) -> List[str]: + """Split source code into a list.""" + newly_split = [] + for line in source_code.splitlines(): + if "import " in line and ";" in line: + import_line, next_line = line.split(";") + leading_spaces = re.search(r"\S", import_line) + num_lspaces = leading_spaces.start() if leading_spaces else 0 + next_line = f"{' ' * num_lspaces}{next_line.lstrip()}" + newly_split.extend([import_line, next_line]) + else: + newly_split.append(line) + return newly_split + def _split_code(self, source_code: str) -> None: """Split the source code in the different sections. @@ -80,7 +94,7 @@ def _split_code(self, source_code: str) -> None: Args: source_code: Source code to be corrected. """ - source_code_lines = source_code.splitlines() + source_code_lines = self._split_into_list(source_code) self._extract_header(source_code_lines) self._extract_import_statements(source_code_lines) diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py index 93de538..109921b 100644 --- a/tests/unit/test_services.py +++ b/tests/unit/test_services.py @@ -1021,3 +1021,53 @@ def test_file_with_import_as() -> None: result = fix_code(source) assert result == "\n" + + +def test_file_with_import_and_seperator() -> None: + """Ensure import lines with seperators are fixed correctly.""" + source = dedent( + """ + a = 1 + import pdb;pdb.set_trace() + b = 2 + """ + ) + expected = dedent( + """ + import pdb + + a = 1 + pdb.set_trace() + b = 2 + """ + ).replace("\n", "", 1) + + result = fix_code(source) + + assert result == expected + + +def test_file_with_import_and_seperator_indentation() -> None: + """Ensure import lines with seperators are fixed correctly when indented.""" + source = dedent( + """ + Class Person: + import pdb; pdb.set_trace() + def say_hi(self): + print('hi') + """ + ) + expected = dedent( + """ + import pdb + + Class Person: + pdb.set_trace() + def say_hi(self): + print('hi') + """ + ).replace("\n", "", 1) + + result = fix_code(source) + + assert result == expected