Skip to content

Commit

Permalink
Self suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Nov 12, 2022
1 parent 236699b commit c98ba01
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions .actions/assistant.py
@@ -1,4 +1,3 @@
import logging
import os
import re
import shutil
Expand All @@ -9,7 +8,6 @@

import pkg_resources

_PATH_ROOT = dirname(dirname(__file__))
REQUIREMENT_FILES = {
"pytorch": (
"requirements/pytorch/base.txt",
Expand Down Expand Up @@ -68,12 +66,8 @@ def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]]) -> List[s
def copy_replace_imports(
source_dir: str, source_imports: List[str], target_imports: List[str], target_dir: Optional[str] = None
) -> None:
"""Copy package content with import adjustments.
>>> _ = copy_replace_imports(os.path.join(
... _PATH_ROOT, "src"), ["lightning_app"], ["lightning.app"], os.path.join(_PATH_ROOT, "src", "lightning"))
"""
logging.info(f"Replacing imports: {locals()}")
"""Copy package content with import adjustments."""
print(f"Replacing imports: {locals()}")
assert len(source_imports) == len(target_imports), (
"source and target imports must have the same length, "
f"source: {len(source_imports)}, target: {len(target_imports)}"
Expand All @@ -98,7 +92,7 @@ def copy_replace_imports(
lines = fo.readlines()
except UnicodeDecodeError:
# a binary file, skip
logging.warning(f"Skipped replacing imports for {fp}")
print(f"Skipped replacing imports for {fp}")
continue
lines = _replace_imports(lines, list(zip(source_imports, target_imports)))
os.makedirs(os.path.dirname(fp_new), exist_ok=True)
Expand Down Expand Up @@ -145,7 +139,7 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
req = list(pkg_resources.parse_requirements(ln_))[0]
if req.name not in packages:
final.append(line)
logging.info(final)
print(final)
path.write_text("\n".join(final))

@staticmethod
Expand All @@ -163,7 +157,7 @@ def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL
def copy_replace_imports(
source_dir: str, source_import: str, target_import: str, target_dir: Optional[str] = None
) -> None:
"""Recursively replace imports in given folder."""
"""Copy package content with import adjustments."""
source_imports = source_import.strip().split(",")
target_imports = target_import.strip().split(",")
copy_replace_imports(source_dir, source_imports, target_imports, target_dir=target_dir)
Expand All @@ -172,5 +166,4 @@ def copy_replace_imports(
if __name__ == "__main__":
import jsonargparse

logging.basicConfig(level=logging.INFO)
jsonargparse.CLI(AssistantCLI, as_positional=False)

0 comments on commit c98ba01

Please sign in to comment.