diff --git a/isort/api.py b/isort/api.py index 611187437..ae17afb89 100644 --- a/isort/api.py +++ b/isort/api.py @@ -21,6 +21,20 @@ from .settings import DEFAULT_CONFIG, Config +def detect_newline(code: str) -> str: + stream = StringIO(code, newline="") + line = stream.readline() + if not line: + newline = "\n" + elif "\r\n" == line[-2:]: + newline = "\r\n" + elif "\r" == line[-1:]: + newline = "\r" + else: + newline = "\n" + return newline + + def sort_code_string( code: str, extension: Optional[str] = None, @@ -41,8 +55,9 @@ def sort_code_string( TextIO stream is provided results will be written to it, otherwise no diff will be computed. - ****config_kwargs**: Any config modifications. """ - input_stream = StringIO(code) - output_stream = StringIO() + newline = detect_newline(code) + input_stream = StringIO(code, newline=None) + output_stream = StringIO(newline=newline) config = _config(path=file_path, config=config, **config_kwargs) sort_stream( input_stream, @@ -92,6 +107,7 @@ def check_code_string( def sort_stream( input_stream: TextIO, output_stream: TextIO, + newline: str, extension: Optional[str] = None, config: Config = DEFAULT_CONFIG, file_path: Optional[Path] = None, @@ -113,7 +129,7 @@ def sort_stream( - ****config_kwargs**: Any config modifications. """ if show_diff: - _output_stream = StringIO() + _output_stream = StringIO(newline=newline) _input_stream = StringIO(input_stream.read()) changed = sort_stream( input_stream=_input_stream, @@ -152,7 +168,7 @@ def sort_stream( raise ExistingSyntaxErrors(content_source) if not output_stream.readable(): - _internal_output = StringIO() + _internal_output = StringIO(newline=newline) try: changed = core.process( @@ -309,6 +325,7 @@ def sort_file( changed = sort_stream( input_stream=source_file.stream, output_stream=sys.stdout, + newline=source_file.newline, config=config, file_path=actual_file_path, disregard_skip=disregard_skip, @@ -318,7 +335,7 @@ def sort_file( tmp_file = source_file.path.with_suffix(source_file.path.suffix + ".isorted") try: with tmp_file.open( - "w", encoding=source_file.encoding, newline="" + "w", encoding=source_file.encoding, newline=source_file.newline ) as output_stream: shutil.copymode(filename, tmp_file) changed = sort_stream( @@ -333,7 +350,7 @@ def sort_file( if show_diff or ask_to_apply: source_file.stream.seek(0) with tmp_file.open( - encoding=source_file.encoding, newline="" + encoding=source_file.encoding, newline=source_file.newline ) as tmp_out: show_unified_diff( file_input=source_file.stream.read(), diff --git a/isort/io.py b/isort/io.py index 7ff2807d2..4c88823ce 100644 --- a/isort/io.py +++ b/isort/io.py @@ -1,10 +1,13 @@ """Defines any IO utilities used by isort""" +import io import re import tokenize from contextlib import contextmanager from io import BytesIO, StringIO, TextIOWrapper from pathlib import Path -from typing import Callable, Iterator, NamedTuple, TextIO, Union +from typing import BinaryIO +from typing import Iterator, NamedTuple, TextIO, Union +from typing import Tuple from isort.exceptions import UnsupportedEncoding @@ -15,50 +18,49 @@ class File(NamedTuple): stream: TextIO path: Path encoding: str + newline: str @staticmethod - def detect_encoding(filename: str, readline: Callable[[], bytes]): + def decode_bytes(filename: str, buffer: BinaryIO) -> Tuple[TextIO, str, str]: try: - return tokenize.detect_encoding(readline)[0] + encoding, lines = tokenize.detect_encoding(buffer.readline) except Exception: raise UnsupportedEncoding(filename) + if not lines: + newline = "\n" + elif b"\r\n" == lines[0][-2:]: + newline = "\r\n" + elif b"\r" == lines[0][-1:]: + newline = "\r" + else: + newline = "\n" + + buffer.seek(0) + text = io.TextIOWrapper(buffer, encoding, line_buffering=True) + return text, encoding, newline + @staticmethod def from_contents(contents: str, filename: str) -> "File": - encoding = File.detect_encoding(filename, BytesIO(contents.encode("utf-8")).readline) - return File(StringIO(contents), path=Path(filename).resolve(), encoding=encoding) + text, encoding, newline = File.decode_bytes(filename, BytesIO(contents.encode("utf-8"))) + return File(StringIO(contents), path=Path(filename).resolve(), encoding=encoding, newline=newline) @property def extension(self): return self.path.suffix.lstrip(".") - @staticmethod - def _open(filename): - """Open a file in read only mode using the encoding detected by - detect_encoding(). - """ - buffer = open(filename, "rb") - try: - encoding = File.detect_encoding(filename, buffer.readline) - buffer.seek(0) - text = TextIOWrapper(buffer, encoding, line_buffering=True, newline="") - text.mode = "r" # type: ignore - return text - except Exception: - buffer.close() - raise - @staticmethod @contextmanager def read(filename: Union[str, Path]) -> Iterator["File"]: file_path = Path(filename).resolve() - stream = None + buffer = None try: - stream = File._open(file_path) - yield File(stream=stream, path=file_path, encoding=stream.encoding) + buffer = open(filename, "rb") + stream, encoding, newline = File.decode_bytes(filename, buffer) + yield File(stream=stream, path=file_path, encoding=encoding, newline=newline) finally: - if stream is not None: - stream.close() + if buffer is not None: + buffer.close() class _EmptyIO(StringIO): diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 4ee19bc43..6659c58af 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -2,11 +2,13 @@ import os import sys from io import StringIO +from pathlib import Path from unittest.mock import MagicMock, patch import pytest from isort import api +from isort.api import detect_newline from isort.settings import Config imperfect_content = "import b\nimport a\n" @@ -15,12 +17,25 @@ @pytest.fixture -def imperfect(tmpdir) -> None: +def imperfect(tmpdir) -> Path: imperfect_file = tmpdir.join("test_needs_changes.py") - imperfect_file.write_text(imperfect_content, "utf8") + with open(imperfect_file, mode="w", encoding="utf-8", newline=os.linesep) as f: + f.write(imperfect_content) return imperfect_file +def test_detect_newline(): + lf: str = "a\nb" + crlf: str = "a\r\nb" + cr: str = "a\rb" + empty: str = "" + + assert "\n" == detect_newline(lf) + assert "\r\n" == detect_newline(crlf) + assert "\r" == detect_newline(cr) + assert "\n" == detect_newline(empty) + + def test_sort_file_with_bad_syntax(tmpdir) -> None: tmp_file = tmpdir.join("test_bad_syntax.py") tmp_file.write_text("""print('mismatching quotes")""", "utf8")