Skip to content

Commit

Permalink
fix(mypy): 💚 add missing typing
Browse files Browse the repository at this point in the history
Includes workarounds for tmbo/questionary#191 and pydantic/pydantic#3175.
  • Loading branch information
yajo committed Dec 19, 2021
1 parent f99f10b commit 534f012
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 82 deletions.
3 changes: 2 additions & 1 deletion copier/errors.py
Expand Up @@ -9,7 +9,8 @@
from .types import PathSeq

if TYPE_CHECKING: # always false
from .user_data import AnswersMap, Question, Template
from .template import Template
from .user_data import AnswersMap, Question


# Errors
Expand Down
11 changes: 6 additions & 5 deletions copier/main.py
Expand Up @@ -10,7 +10,7 @@
from itertools import chain
from pathlib import Path
from shutil import rmtree
from typing import Callable, List, Mapping, Optional, Sequence
from typing import Callable, Iterable, List, Mapping, Optional, Sequence
from unicodedata import normalize

import pathspec
Expand Down Expand Up @@ -41,7 +41,8 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property
# HACK https://github.com/python/mypy/issues/1153#issuecomment-558556828
from backports.cached_property import cached_property # type: ignore


@dataclass
Expand Down Expand Up @@ -131,7 +132,7 @@ class Worker:
"""

src_path: Optional[str] = None
dst_path: Path = field(default=".")
dst_path: Path = field(default=Path("."))
answers_file: Optional[RelativePath] = None
vcs_ref: OptStr = None
data: AnyByStrDict = field(default_factory=dict)
Expand Down Expand Up @@ -210,7 +211,7 @@ def _render_context(self) -> Mapping:
_folder_name=self.subproject.local_abspath.name,
)

def _path_matcher(self, patterns: StrSeq) -> Callable[[Path], bool]:
def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]:
"""Produce a function that matches against specified patterns."""
# TODO Is normalization really needed?
normalized_patterns = (normalize("NFD", pattern) for pattern in patterns)
Expand Down Expand Up @@ -347,7 +348,7 @@ def answers(self) -> AnswersMap:
question.get_default()
if self.defaults
else unsafe_prompt(
question.get_questionary_structure(), answers=result.combined
[question.get_questionary_structure()], answers=result.combined
)[question.var_name]
)
except KeyboardInterrupt as err:
Expand Down
6 changes: 4 additions & 2 deletions copier/subproject.py
Expand Up @@ -3,6 +3,7 @@
A *subproject* is a project that gets rendered and/or updated with Copier.
"""

import sys
from pathlib import Path
from typing import Optional

Expand All @@ -15,9 +16,10 @@
from .types import AbsolutePath, AnyByStrDict, VCSTypes
from .vcs import is_in_git_repo

try:
# HACK https://github.com/python/mypy/issues/8520#issuecomment-772081075
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from backports.cached_property import cached_property


Expand Down
9 changes: 3 additions & 6 deletions copier/template.py
Expand Up @@ -29,13 +29,10 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
# HACK https://github.com/python/mypy/issues/1153#issuecomment-558556828
from backports.cached_property import cached_property # type: ignore

from .types import Literal

# Default list of files in the template to exclude from the rendered project
DEFAULT_EXCLUDE: Tuple[str, ...] = (
Expand Down
9 changes: 6 additions & 3 deletions copier/tools.py
Expand Up @@ -9,13 +9,14 @@
import warnings
from contextlib import suppress
from pathlib import Path
from typing import Any, Callable, Optional, TextIO, Union
from types import TracebackType
from typing import Any, Callable, Optional, TextIO, Tuple, Union

import colorama
from packaging.version import Version
from pydantic import StrictBool

from .types import ExcInfo, IntSeq
from .types import IntSeq

try:
from importlib.metadata import version
Expand Down Expand Up @@ -128,7 +129,9 @@ def force_str_end(original_str: str, end: str = "\n") -> str:
return original_str


def handle_remove_readonly(func: Callable, path: str, exc: ExcInfo) -> None:
def handle_remove_readonly(
func: Callable, path: str, exc: Tuple[BaseException, OSError, TracebackType]
) -> None:
"""Handle errors when trying to remove read-only files through `shutil.rmtree`.
This handler makes sure the given file is writable, then re-execute the given removal function.
Expand Down
8 changes: 4 additions & 4 deletions copier/types.py
@@ -1,7 +1,7 @@
"""Complex types, annotations, validators."""

import sys
from pathlib import Path
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -18,9 +18,10 @@

from pydantic.validators import path_validator

try:
# HACK https://github.com/python/mypy/issues/8520#issuecomment-772081075
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError:
else:
from typing_extensions import Literal

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,7 +55,6 @@
Filters = Dict[str, Callable]
LoaderPaths = Union[str, Iterable[str]]
VCSTypes = Literal["git"]
ExcInfo = Tuple[BaseException, Exception, TracebackType]


class AllowArbitraryTypes:
Expand Down
8 changes: 5 additions & 3 deletions copier/user_data.py
@@ -1,6 +1,7 @@
"""Functions used to load user data."""
import datetime
import json
import sys
import warnings
from collections import ChainMap
from dataclasses import field
Expand Down Expand Up @@ -30,9 +31,10 @@
from .tools import cast_str_to_bool, force_str_end
from .types import AllowArbitraryTypes, AnyByStrDict, OptStr, OptStrOrPath, StrOrPath

try:
# HACK https://github.com/python/mypy/issues/8520#issuecomment-772081075
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from backports.cached_property import cached_property

if TYPE_CHECKING:
Expand Down Expand Up @@ -249,7 +251,7 @@ def get_default_rendered(self) -> Union[bool, str, Choice, None]:
return json.dumps(default, indent=2 if self.get_multiline() else None)
if self.get_type_name() == "yaml":
return yaml.safe_dump(
default, default_flow_style=not self.get_multiline(), width=float("inf")
default, default_flow_style=not self.get_multiline(), width=2147483647
).strip()
# All other data has to be str
return str(default)
Expand Down
3 changes: 0 additions & 3 deletions mypy.ini

This file was deleted.

0 comments on commit 534f012

Please sign in to comment.