Skip to content

Commit

Permalink
Merge pull request #1545 from timqsh/issue/1536
Browse files Browse the repository at this point in the history
Provide API and CLI to list and stream imports
  • Loading branch information
timothycrosley committed Oct 15, 2020
2 parents 6421a5b + e24c3b0 commit 1e38354
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 1 deletion.
9 changes: 8 additions & 1 deletion isort/__init__.py
Expand Up @@ -2,7 +2,14 @@
from . import settings
from ._version import __version__
from .api import check_code_string as check_code
from .api import check_file, check_stream, place_module, place_module_with_reason
from .api import (
check_file,
check_stream,
get_imports_stream,
get_imports_string,
place_module,
place_module_with_reason,
)
from .api import sort_code_string as code
from .api import sort_file as file
from .api import sort_stream as stream
Expand Down
56 changes: 56 additions & 0 deletions isort/api.py
Expand Up @@ -366,6 +366,62 @@ def sort_file(
return changed


def get_imports_string(
code: str,
extension: Optional[str] = None,
config: Config = DEFAULT_CONFIG,
file_path: Optional[Path] = None,
**config_kwargs,
) -> str:
"""Finds all imports within the provided code string, returning a new string with them.
- **code**: The string of code with imports that need to be sorted.
- **extension**: The file extension that contains imports. Defaults to filename extension or py.
- **config**: The config object to use when sorting imports.
- **file_path**: The disk location where the code string was pulled from.
- ****config_kwargs**: Any config modifications.
"""
input_stream = StringIO(code)
output_stream = StringIO()
config = _config(path=file_path, config=config, **config_kwargs)
get_imports_stream(
input_stream,
output_stream,
extension=extension,
config=config,
file_path=file_path,
)
output_stream.seek(0)
return output_stream.read()


def get_imports_stream(
input_stream: TextIO,
output_stream: TextIO,
extension: Optional[str] = None,
config: Config = DEFAULT_CONFIG,
file_path: Optional[Path] = None,
**config_kwargs,
) -> None:
"""Finds all imports within the provided code stream, outputs to the provided output stream.
- **input_stream**: The stream of code with imports that need to be sorted.
- **output_stream**: The stream where sorted imports should be written to.
- **extension**: The file extension that contains imports. Defaults to filename extension or py.
- **config**: The config object to use when sorting imports.
- **file_path**: The disk location where the code string was pulled from.
- ****config_kwargs**: Any config modifications.
"""
config = _config(path=file_path, config=config, **config_kwargs)
core.process(
input_stream,
output_stream,
extension=extension or (file_path and file_path.suffix.lstrip(".")) or "py",
config=config,
imports_only=True,
)


def _config(
path: Optional[Path] = None, config: Config = DEFAULT_CONFIG, **config_kwargs
) -> Config:
Expand Down
19 changes: 19 additions & 0 deletions isort/core.py
Expand Up @@ -30,6 +30,7 @@ def process(
output_stream: TextIO,
extension: str = "py",
config: Config = DEFAULT_CONFIG,
imports_only: bool = False,
) -> bool:
"""Parses stream identifying sections of contiguous imports and sorting them
Expand Down Expand Up @@ -68,6 +69,15 @@ def process(
stripped_line: str = ""
end_of_file: bool = False
verbose_output: List[str] = []
all_imports: List[str] = []
if imports_only:
_output_stream = output_stream

class DevNull(StringIO):
def write(self, *a, **kw):
pass

output_stream = DevNull()

if config.float_to_top:
new_input = ""
Expand Down Expand Up @@ -331,6 +341,11 @@ def process(

parsed_content = parse.file_contents(import_section, config=config)
verbose_output += parsed_content.verbose_output
all_imports.extend(
li
for li in parsed_content.in_lines
if li and li not in set(parsed_content.lines_without_imports)
)

sorted_import_section = output.sorted_imports(
parsed_content,
Expand Down Expand Up @@ -395,6 +410,10 @@ def process(
for output_str in verbose_output:
print(output_str)

if imports_only:
result = line_separator.join(all_imports) + line_separator
_output_stream.write(result)

return made_changes


Expand Down
63 changes: 63 additions & 0 deletions tests/unit/test_isort.py
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
import subprocess
import sys
from io import StringIO
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterator, List, Set, Tuple

Expand Down Expand Up @@ -4913,3 +4914,65 @@ def test_combine_straight_imports() -> None:
assert isort.code(test_input, combine_straight_imports=True, only_sections=True) == (
"import sys, os, math\n" "\n" "import a, b\n"
)


def test_get_imports_string() -> None:
test_input = (
"import first_straight\n"
"\n"
"import second_straight\n"
"from first_from import first_from_function_1, first_from_function_2\n"
"import bad_name as good_name\n"
"from parent.some_bad_defs import bad_name_1 as ok_name_1, bad_name_2 as ok_name_2\n"
"\n"
"# isort: list\n"
"__all__ = ['b', 'c', 'a']\n"
"\n"
"def bla():\n"
" import needed_in_bla_2\n"
"\n"
"\n"
" import needed_in_bla\n"
" pass"
"\n"
"def bla_bla():\n"
" import needed_in_bla_bla\n"
"\n"
" #import not_really_an_import\n"
" pass"
"\n"
"import needed_in_end\n"
)
result = api.get_imports_string(test_input)
assert result == (
"import first_straight\n"
"import second_straight\n"
"from first_from import first_from_function_1, first_from_function_2\n"
"import bad_name as good_name\n"
"from parent.some_bad_defs import bad_name_1 as ok_name_1, bad_name_2 as ok_name_2\n"
"import needed_in_bla_2\n"
"import needed_in_bla\n"
"import needed_in_bla_bla\n"
"import needed_in_end\n"
)


def test_get_imports_stdout() -> None:
"""Ensure that get_imports_stream can work with nonseekable streams like STDOUT"""

global_output = []

class NonSeekableTestStream(StringIO):
def seek(self, position):
raise OSError("Stream is not seekable")

def seekable(self):
return False

def write(self, s, *a, **kw):
global_output.append(s)

test_input = StringIO("import m2\n" "import m1\n" "not_import = 7")
test_output = NonSeekableTestStream()
api.get_imports_stream(test_input, test_output)
assert "".join(global_output) == "import m2\nimport m1\n"

0 comments on commit 1e38354

Please sign in to comment.