From 78771148995f3a1d3a12f0177d5ab153dd6de3f0 Mon Sep 17 00:00:00 2001 From: Timur Kushukov Date: Thu, 8 Oct 2020 00:04:50 +0500 Subject: [PATCH 1/3] get imports command --- isort/__init__.py | 9 ++++++- isort/api.py | 56 ++++++++++++++++++++++++++++++++++++++++ isort/core.py | 12 +++++++++ tests/unit/test_isort.py | 41 +++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 1 deletion(-) diff --git a/isort/__init__.py b/isort/__init__.py index 236255dd8..6f2c2c833 100644 --- a/isort/__init__.py +++ b/isort/__init__.py @@ -2,7 +2,14 @@ from . import settings from ._version import __version__ from .api import check_code_string as check_code -from .api import check_file, check_stream, place_module, place_module_with_reason +from .api import ( + check_file, + check_stream, + get_imports_stream, + get_imports_string, + place_module, + place_module_with_reason, +) from .api import sort_code_string as code from .api import sort_file as file from .api import sort_stream as stream diff --git a/isort/api.py b/isort/api.py index 5a2df6af3..aed041ce4 100644 --- a/isort/api.py +++ b/isort/api.py @@ -366,6 +366,62 @@ def sort_file( return changed +def get_imports_string( + code: str, + extension: Optional[str] = None, + config: Config = DEFAULT_CONFIG, + file_path: Optional[Path] = None, + **config_kwargs, +) -> str: + """Finds all imports within the provided code string, returning a new string with them. + + - **code**: The string of code with imports that need to be sorted. + - **extension**: The file extension that contains imports. Defaults to filename extension or py. + - **config**: The config object to use when sorting imports. + - **file_path**: The disk location where the code string was pulled from. + - ****config_kwargs**: Any config modifications. + """ + input_stream = StringIO(code) + output_stream = StringIO() + config = _config(path=file_path, config=config, **config_kwargs) + get_imports_stream( + input_stream, + output_stream, + extension=extension, + config=config, + file_path=file_path, + ) + output_stream.seek(0) + return output_stream.read() + + +def get_imports_stream( + input_stream: TextIO, + output_stream: TextIO, + extension: Optional[str] = None, + config: Config = DEFAULT_CONFIG, + file_path: Optional[Path] = None, + **config_kwargs, +) -> None: + """Finds all imports within the provided code stream, outputs to the provided output stream. + + - **input_stream**: The stream of code with imports that need to be sorted. + - **output_stream**: The stream where sorted imports should be written to. + - **extension**: The file extension that contains imports. Defaults to filename extension or py. + - **config**: The config object to use when sorting imports. + - **file_path**: The disk location where the code string was pulled from. + - ****config_kwargs**: Any config modifications. + """ + config = _config(path=file_path, config=config, **config_kwargs) + core.process( + input_stream, + output_stream, + extension=extension or (file_path and file_path.suffix.lstrip(".")) or "py", + config=config, + imports_only=True, + ) + + def _config( path: Optional[Path] = None, config: Config = DEFAULT_CONFIG, **config_kwargs ) -> Config: diff --git a/isort/core.py b/isort/core.py index 292bdc1c2..3668ae9e8 100644 --- a/isort/core.py +++ b/isort/core.py @@ -30,6 +30,7 @@ def process( output_stream: TextIO, extension: str = "py", config: Config = DEFAULT_CONFIG, + imports_only: bool = False, ) -> bool: """Parses stream identifying sections of contiguous imports and sorting them @@ -68,6 +69,7 @@ def process( stripped_line: str = "" end_of_file: bool = False verbose_output: List[str] = [] + all_imports: List[str] = [] if config.float_to_top: new_input = "" @@ -331,6 +333,11 @@ def process( parsed_content = parse.file_contents(import_section, config=config) verbose_output += parsed_content.verbose_output + all_imports.extend( + li + for li in parsed_content.in_lines + if li and li not in set(parsed_content.lines_without_imports) + ) sorted_import_section = output.sorted_imports( parsed_content, @@ -395,6 +402,11 @@ def process( for output_str in verbose_output: print(output_str) + if imports_only: + output_stream.seek(0) + output_stream.truncate(0) + output_stream.write(line_separator.join(all_imports) + line_separator) + return made_changes diff --git a/tests/unit/test_isort.py b/tests/unit/test_isort.py index 01fba9a77..f28c5de63 100644 --- a/tests/unit/test_isort.py +++ b/tests/unit/test_isort.py @@ -4913,3 +4913,44 @@ def test_combine_straight_imports() -> None: assert isort.code(test_input, combine_straight_imports=True, only_sections=True) == ( "import sys, os, math\n" "\n" "import a, b\n" ) + + +def test_get_imports_string() -> None: + test_input = ( + "import first_straight\n" + "\n" + "import second_straight\n" + "from first_from import first_from_function_1, first_from_function_2\n" + "import bad_name as good_name\n" + "from parent.some_bad_defs import bad_name_1 as ok_name_1, bad_name_2 as ok_name_2\n" + "\n" + "# isort: list\n" + "__all__ = ['b', 'c', 'a']\n" + "\n" + "def bla():\n" + " import needed_in_bla_2\n" + "\n" + "\n" + " import needed_in_bla\n" + " pass" + "\n" + "def bla_bla():\n" + " import needed_in_bla_bla\n" + "\n" + " #import not_really_an_import\n" + " pass" + "\n" + "import needed_in_end\n" + ) + result = api.get_imports_string(test_input) + assert result == ( + "import first_straight\n" + "import second_straight\n" + "from first_from import first_from_function_1, first_from_function_2\n" + "import bad_name as good_name\n" + "from parent.some_bad_defs import bad_name_1 as ok_name_1, bad_name_2 as ok_name_2\n" + "import needed_in_bla_2\n" + "import needed_in_bla\n" + "import needed_in_bla_bla\n" + "import needed_in_end\n" + ) From 68bf16a0de7391ff5036ff1fac530f4ea359e489 Mon Sep 17 00:00:00 2001 From: Timur Kushukov Date: Wed, 14 Oct 2020 11:45:33 +0500 Subject: [PATCH 2/3] get imports stdout fix --- isort/core.py | 9 ++++++--- tests/unit/test_isort.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/isort/core.py b/isort/core.py index 3668ae9e8..c62687e5a 100644 --- a/isort/core.py +++ b/isort/core.py @@ -1,3 +1,4 @@ +import os import textwrap from io import StringIO from itertools import chain @@ -70,6 +71,9 @@ def process( end_of_file: bool = False verbose_output: List[str] = [] all_imports: List[str] = [] + if imports_only: + _output_stream = output_stream + output_stream = open(os.devnull, "wt") if config.float_to_top: new_input = "" @@ -403,9 +407,8 @@ def process( print(output_str) if imports_only: - output_stream.seek(0) - output_stream.truncate(0) - output_stream.write(line_separator.join(all_imports) + line_separator) + result = line_separator.join(all_imports) + line_separator + _output_stream.write(result) return made_changes diff --git a/tests/unit/test_isort.py b/tests/unit/test_isort.py index f28c5de63..cdec1c4e2 100644 --- a/tests/unit/test_isort.py +++ b/tests/unit/test_isort.py @@ -7,6 +7,7 @@ from pathlib import Path import subprocess import sys +from io import StringIO from tempfile import NamedTemporaryFile from typing import Any, Dict, Iterator, List, Set, Tuple @@ -4954,3 +4955,24 @@ def test_get_imports_string() -> None: "import needed_in_bla_bla\n" "import needed_in_end\n" ) + + +def test_get_imports_stdout() -> None: + """Ensure that get_imports_stream can work with nonseekable streams like STDOUT""" + + global_output = [] + + class NonSeekableTestStream(StringIO): + def seek(self, position): + raise OSError("Stream is not seekable") + + def seekable(self): + return False + + def write(self, s): + global_output.append(s) + + test_input = StringIO("import m2\n" "import m1\n" "not_import = 7") + test_output = NonSeekableTestStream() + api.get_imports_stream(test_input, test_output) + assert "".join(global_output) == "import m2\nimport m1\n" From e24c3b0b144f73319f12193155c66ef24a071d8d Mon Sep 17 00:00:00 2001 From: Timur Kushukov Date: Thu, 15 Oct 2020 12:35:06 +0500 Subject: [PATCH 3/3] fix devnull linter warning --- isort/core.py | 8 ++++++-- tests/unit/test_isort.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/isort/core.py b/isort/core.py index c62687e5a..badf2a04c 100644 --- a/isort/core.py +++ b/isort/core.py @@ -1,4 +1,3 @@ -import os import textwrap from io import StringIO from itertools import chain @@ -73,7 +72,12 @@ def process( all_imports: List[str] = [] if imports_only: _output_stream = output_stream - output_stream = open(os.devnull, "wt") + + class DevNull(StringIO): + def write(self, *a, **kw): + pass + + output_stream = DevNull() if config.float_to_top: new_input = "" diff --git a/tests/unit/test_isort.py b/tests/unit/test_isort.py index cdec1c4e2..aea2c74b6 100644 --- a/tests/unit/test_isort.py +++ b/tests/unit/test_isort.py @@ -4969,7 +4969,7 @@ def seek(self, position): def seekable(self): return False - def write(self, s): + def write(self, s, *a, **kw): global_output.append(s) test_input = StringIO("import m2\n" "import m1\n" "not_import = 7")