Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend common_statements #160

Merged
merged 7 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,23 @@ from os import getcwd # noqa: autoimport
getcwd()
```


# Configuration

`autoimport` uses the `maison` library to discover and read your project-local
`pyproject.toml` file (if it exists). This file can be used to configure
`autoimport`'s behavior: the `tool.autoimport.common_statements` table in that
file can be used to define a custom set of "common statements", overriding the
default set of common statements mentioned above. For example:

```toml
# pyproject.toml

[tool.autoimport.common_statements]
"np" = "import numpy as np"
"FooBar" = "from baz_qux import FooBar"
```

# References

As most open sourced programs, `autoimport` is standing on the shoulders of
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@
[console_scripts]
autoimport=autoimport.entrypoints.cli:cli
""",
install_requires=["autoflake", "Click", "pyprojroot", "sh"],
install_requires=["autoflake", "Click", "pyprojroot", "sh", "maison"],
)
56 changes: 0 additions & 56 deletions src/autoimport/config.py

This file was deleted.

4 changes: 3 additions & 1 deletion src/autoimport/entrypoints/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Tuple

import click
from maison.config import ProjectConfig

from autoimport import services, version

Expand All @@ -15,8 +16,9 @@
@click.argument("files", type=click.File("r+"), nargs=-1)
def cli(files: Tuple[str]) -> None:
"""Corrects the source code of the specified files."""
config = ProjectConfig(project_name="autoimport").to_dict()
try:
fixed_code = services.fix_files(files)
fixed_code = services.fix_files(files, config)
except FileNotFoundError as error:
log.error(error)

Expand Down
18 changes: 12 additions & 6 deletions src/autoimport/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
import re
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import autoflake
from pyflakes.messages import UndefinedExport, UndefinedName, UnusedImport
Expand Down Expand Up @@ -37,12 +37,15 @@
class SourceCode: # noqa: R090
"""Python source code entity."""

def __init__(self, source_code: str) -> None:
def __init__(
self, source_code: str, config: Optional[Dict[str, Any]] = None
) -> None:
"""Initialize the object."""
self.header: List[str] = []
self.imports: List[str] = []
self.typing: List[str] = []
self.code: List[str] = []
self.config: Dict[str, Any] = config if config else {}
lyz-code marked this conversation as resolved.
Show resolved Hide resolved
self._trailing_newline = False
self._split_code(source_code)

Expand Down Expand Up @@ -356,8 +359,7 @@ def _find_package_in_typing(name: str) -> Optional[str]:
except KeyError:
return None

@staticmethod
def _find_package_in_common_statements(name: str) -> Optional[str]:
def _find_package_in_common_statements(self, name: str) -> Optional[str]:
"""Search in the common statements the object name.

Args:
Expand All @@ -366,8 +368,12 @@ def _find_package_in_common_statements(name: str) -> Optional[str]:
Returns:
import_string
"""
if name in common_statements:
return common_statements[name]
local_common_statements = common_statements.copy()
if "common_statements" in self.config:
local_common_statements.update(self.config["common_statements"])

if name in local_common_statements:
return local_common_statements[name]

return None

Expand Down
12 changes: 7 additions & 5 deletions src/autoimport/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
and handlers to achieve the program's purpose.
"""

from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from _io import TextIOWrapper

from autoimport.model import SourceCode


def fix_files(files: Tuple[TextIOWrapper]) -> Optional[str]:
def fix_files(
files: Tuple[TextIOWrapper], config: Optional[Dict[str, Any]] = None
) -> Optional[str]:
"""Fix the python source code of a list of files.

If the input is taken from stdin, it will output the value to stdout.
Expand All @@ -24,7 +26,7 @@ def fix_files(files: Tuple[TextIOWrapper]) -> Optional[str]:
"""
for file_wrapper in files:
source = file_wrapper.read()
fixed_source = fix_code(source)
fixed_source = fix_code(source, config)

try:
# Click testing runner doesn't simulate correctly the reading from stdin
Expand All @@ -49,7 +51,7 @@ def fix_files(files: Tuple[TextIOWrapper]) -> Optional[str]:
return None


def fix_code(original_source_code: str) -> str:
def fix_code(original_source_code: str, config: Optional[Dict[str, Any]] = None) -> str:
lyz-code marked this conversation as resolved.
Show resolved Hide resolved
"""Fix python source code to correct import statements.

It corrects these errors:
Expand All @@ -64,4 +66,4 @@ def fix_code(original_source_code: str) -> str:
Returns:
Corrected source code.
"""
return SourceCode(original_source_code).fix()
return SourceCode(original_source_code, config=config).fix()
36 changes: 0 additions & 36 deletions src/autoimport/utils.py

This file was deleted.

19 changes: 0 additions & 19 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1 @@
"""Store the classes and fixtures used throughout the tests."""

from pathlib import Path
from typing import Callable, Optional

import pytest


@pytest.fixture()
def create_tmp_file(tmp_path: Path) -> Callable:
"""Fixture for creating a temporary file."""

def _create_tmp_file(
content: Optional[str] = "", filename: Optional[str] = "file.txt"
) -> Path:
tmp_file = tmp_path / filename
tmp_file.write_text(content)
return tmp_file

return _create_tmp_file
28 changes: 28 additions & 0 deletions tests/e2e/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,31 @@ def test_corrects_code_from_stdin(runner: CliRunner) -> None:

assert result.exit_code == 0
assert result.stdout == fixed_source


def test_pyproject_common_statements(runner: CliRunner, tmpdir: LocalPath) -> None:
"""Allow common_statements to be defined in pyproject.toml"""
pyproject_toml = tmpdir.join("pyproject.toml") # type: ignore
pyproject_toml.write(
dedent(
"""\
[tool.autoimport]
common_statements = { "FooBar" = "from baz_qux import FooBar" }
"""
)
)
test_file = tmpdir.join("source.py") # type: ignore
test_file.write("FooBar\n")
fixed_source = dedent(
"""\
from baz_qux import FooBar

FooBar
"""
)
with tmpdir.as_cwd():

result = runner.invoke(cli, [str(test_file)])

assert result.exit_code == 0
assert test_file.read() == fixed_source