From 2dfb6dc62ad9ba5d61acecdb7f9e6e7204594b6f Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 10 May 2021 22:41:52 -0700 Subject: [PATCH] enable more mypy checks --- setup.cfg | 14 ++- src/werkzeug/_internal.py | 48 ++++++---- src/werkzeug/_reloader.py | 14 +-- src/werkzeug/datastructures.pyi | 92 ++++++++++--------- src/werkzeug/debug/__init__.py | 24 +++-- src/werkzeug/debug/console.py | 6 +- src/werkzeug/debug/tbtools.py | 2 +- src/werkzeug/exceptions.py | 22 +++-- src/werkzeug/filesystem.py | 3 +- src/werkzeug/formparser.py | 21 +++-- src/werkzeug/http.py | 62 ++++++++----- src/werkzeug/local.py | 99 ++++++++++++--------- src/werkzeug/middleware/lint.py | 22 ++--- src/werkzeug/middleware/profiler.py | 6 +- src/werkzeug/middleware/shared_data.py | 33 +++++-- src/werkzeug/routing.py | 59 +++++++----- src/werkzeug/sansio/response.py | 18 ++-- src/werkzeug/serving.py | 49 +++++----- src/werkzeug/test.py | 51 ++++++----- src/werkzeug/testapp.py | 9 +- src/werkzeug/urls.py | 36 ++++---- src/werkzeug/user_agent.py | 2 +- src/werkzeug/useragents.py | 6 +- src/werkzeug/utils.py | 41 ++++----- src/werkzeug/wrappers/accept.py | 5 +- src/werkzeug/wrappers/auth.py | 9 +- src/werkzeug/wrappers/base_request.py | 7 +- src/werkzeug/wrappers/base_response.py | 7 +- src/werkzeug/wrappers/common_descriptors.py | 9 +- src/werkzeug/wrappers/cors.py | 9 +- src/werkzeug/wrappers/etag.py | 9 +- src/werkzeug/wrappers/json.py | 5 +- src/werkzeug/wrappers/request.py | 52 +++++++---- src/werkzeug/wrappers/response.py | 16 ++-- src/werkzeug/wrappers/user_agent.py | 5 +- src/werkzeug/wsgi.py | 22 +++-- 36 files changed, 533 insertions(+), 361 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6e4f08d6f..d95ffca88 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,18 +90,24 @@ python_version = 3.6 allow_redefinition = True disallow_subclassing_any = True # disallow_untyped_calls = True -# disallow_untyped_defs = True -# disallow_incomplete_defs = True +disallow_untyped_defs = True +disallow_incomplete_defs = True no_implicit_optional = True local_partial_types = True -# no_implicit_reexport = True +no_implicit_reexport = True strict_equality = True warn_redundant_casts = True warn_unused_configs = True warn_unused_ignores = True -# warn_return_any = True +warn_return_any = True # warn_unreachable = True +[mypy-werkzeug] +no_implicit_reexport = False + +[mypy-werkzeug.wrappers] +no_implicit_reexport = False + [mypy-colorama.*] ignore_missing_imports = True diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index e220ecd6c..71a5e2883 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -13,6 +13,7 @@ from weakref import WeakKeyDictionary if t.TYPE_CHECKING: + from wsgiref.types import StartResponse from wsgiref.types import WSGIApplication from wsgiref.types import WSGIEnvironment from .wrappers.request import Request # noqa: F401 @@ -48,10 +49,10 @@ class _Missing: - def __repr__(self): + def __repr__(self) -> str: return "no value" - def __reduce__(self): + def __reduce__(self) -> str: return "_missing" @@ -68,7 +69,7 @@ def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]: ... -def _make_encode_wrapper(reference): +def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]: """Create a function that will be called with a string argument. If the reference is bytes, values will be encoded to bytes. """ @@ -127,7 +128,12 @@ def _to_str( ... -def _to_str(x, charset=_default_encoding, errors="strict", allow_none_charset=False): +def _to_str( + x: t.Optional[t.Any], + charset: t.Optional[str] = _default_encoding, + errors: str = "strict", + allow_none_charset: bool = False, +) -> t.Optional[t.Union[str, bytes]]: if x is None or isinstance(x, str): return x @@ -138,7 +144,7 @@ def _to_str(x, charset=_default_encoding, errors="strict", allow_none_charset=Fa if allow_none_charset: return x - return x.decode(charset, errors) + return x.decode(charset, errors) # type: ignore def _wsgi_decoding_dance( @@ -186,7 +192,7 @@ def _has_level_handler(logger: logging.Logger) -> bool: class _ColorStreamHandler(logging.StreamHandler): """On Windows, wrap stream with Colorama for ANSI style support.""" - def __init__(self): + def __init__(self) -> None: try: import colorama except ImportError: @@ -197,7 +203,7 @@ def __init__(self): super().__init__(stream) -def _log(type: str, message: str, *args, **kwargs) -> None: +def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None: """Log a message to the 'werkzeug' logger. The logger is created the first time it is needed. If there is no @@ -219,7 +225,7 @@ def _log(type: str, message: str, *args, **kwargs) -> None: getattr(_logger, type)(message.rstrip(), *args, **kwargs) -def _parse_signature(func): +def _parse_signature(func): # type: ignore """Return a signature object for the function. .. deprecated:: 2.0 @@ -251,7 +257,7 @@ def _parse_signature(func): arguments.append(param) arguments = tuple(arguments) - def parse(args, kwargs): + def parse(args, kwargs): # type: ignore new_args = [] missing = [] extra = {} @@ -306,7 +312,7 @@ def _dt_as_utc(dt: datetime) -> datetime: ... -def _dt_as_utc(dt): +def _dt_as_utc(dt: t.Optional[datetime]) -> t.Optional[datetime]: if dt is None: return dt @@ -356,14 +362,16 @@ def __get__( def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: ... - def __get__(self, instance, owner): + def __get__( + self, instance: t.Optional[t.Any], owner: type + ) -> t.Union[_TAccessorValue, "_DictAccessorProperty[_TAccessorValue]"]: if instance is None: return self storage = self.lookup(instance) if self.name not in storage: - return self.default + return self.default # type: ignore value = storage[self.name] @@ -371,9 +379,9 @@ def __get__(self, instance, owner): try: return self.load_func(value) except (ValueError, TypeError): - return self.default + return self.default # type: ignore - return value + return value # type: ignore def __set__(self, instance: t.Any, value: _TAccessorValue) -> None: if self.read_only: @@ -513,7 +521,7 @@ def _make_cookie_domain(domain: str) -> bytes: ... -def _make_cookie_domain(domain): +def _make_cookie_domain(domain: t.Optional[str]) -> t.Optional[bytes]: if domain is None: return None domain = _encode_idna(domain) @@ -533,7 +541,7 @@ def _make_cookie_domain(domain): def _easteregg(app: t.Optional["WSGIApplication"] = None) -> "WSGIApplication": """Like the name says. But who knows how it works?""" - def bzzzzzzz(gyver): + def bzzzzzzz(gyver: bytes) -> str: import base64 import zlib @@ -579,8 +587,12 @@ def bzzzzzzz(gyver): ] ) - def easteregged(environ, start_response): - def injecting_start_response(status, headers, exc_info=None): + def easteregged( + environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: + def injecting_start_response( + status: str, headers: t.List[t.Tuple[str, str]], exc_info: t.Any = None + ) -> t.Callable[[bytes], t.Any]: headers.append(("X-Powered-By", "Werkzeug")) return start_response(status, headers, exc_info) diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index ace5e19d7..ab34533d9 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -127,7 +127,7 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: rv = set() - def _walk(node, path): + def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None: for prefix, child in node.items(): _walk(child, path + (prefix,)) @@ -218,7 +218,7 @@ def __enter__(self) -> "ReloaderLoop": self.run_step() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore """Clean up any resources associated with the reloader.""" pass @@ -284,7 +284,7 @@ def run_step(self) -> None: class WatchdogReloaderLoop(ReloaderLoop): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: from watchdog.observers import Observer from watchdog.events import PatternMatchingEventHandler @@ -292,7 +292,7 @@ def __init__(self, *args, **kwargs) -> None: trigger_reload = self.trigger_reload class EventHandler(PatternMatchingEventHandler): # type: ignore - def on_any_event(self, event): + def on_any_event(self, event): # type: ignore trigger_reload(event.src_path) reloader_name = Observer.__name__.lower() @@ -331,7 +331,7 @@ def __enter__(self) -> ReloaderLoop: self.observer.start() return super().__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.observer.stop() self.observer.join() @@ -379,7 +379,7 @@ def run_step(self) -> None: reloader_loops["auto"] = reloader_loops["watchdog"] -def ensure_echo_on(): +def ensure_echo_on() -> None: """Ensure that echo mode is enabled. Some tools such as PDB disable it which causes usability issues after a reload.""" # tcgetattr will fail if stdin isn't a tty @@ -404,7 +404,7 @@ def run_with_reloader( exclude_patterns: t.Optional[t.Iterable[str]] = None, interval: t.Union[int, float] = 1, reloader_type: str = "auto", -): +) -> None: """Run the given function in an independent Python interpreter.""" import signal diff --git a/src/werkzeug/datastructures.pyi b/src/werkzeug/datastructures.pyi index 46afc191c..7279d3a73 100644 --- a/src/werkzeug/datastructures.pyi +++ b/src/werkzeug/datastructures.pyi @@ -36,17 +36,21 @@ def iter_multi_items( class ImmutableListMixin(List[V]): _hash_cache: Optional[int] def __hash__(self) -> int: ... # type: ignore - def __delitem__(self, key) -> NoReturn: ... - def __iadd__(self, other) -> NoReturn: ... # type: ignore - def __imul__(self, other) -> NoReturn: ... # type: ignore - def __setitem__(self, key, value) -> NoReturn: ... - def append(self, value) -> NoReturn: ... - def remove(self, value) -> NoReturn: ... - def extend(self, values) -> NoReturn: ... - def insert(self, pos, value) -> NoReturn: ... - def pop(self, index=-1) -> NoReturn: ... + def __delitem__(self, key: Union[int, slice]) -> NoReturn: ... + def __iadd__(self, other: t.Any) -> NoReturn: ... # type: ignore + def __imul__(self, other: int) -> NoReturn: ... + def __setitem__( # type: ignore + self, key: Union[int, slice], value: V + ) -> NoReturn: ... + def append(self, value: V) -> NoReturn: ... + def remove(self, value: V) -> NoReturn: ... + def extend(self, values: Iterable[V]) -> NoReturn: ... + def insert(self, pos: int, value: V) -> NoReturn: ... + def pop(self, index: int = -1) -> NoReturn: ... def reverse(self) -> NoReturn: ... - def sort(self, key=None, reverse=False) -> NoReturn: ... + def sort( + self, key: Optional[Callable[[V], Any]] = None, reverse: bool = False + ) -> NoReturn: ... class ImmutableList(ImmutableListMixin[V]): ... @@ -58,21 +62,23 @@ class ImmutableDictMixin(Dict[K, V]): ) -> ImmutableDictMixin[K, V]: ... def _iter_hashitems(self) -> Iterable[Hashable]: ... def __hash__(self) -> int: ... # type: ignore - def setdefault(self, key, default=None) -> NoReturn: ... - def update(self, *args, **kwargs) -> NoReturn: ... - def pop(self, key, default=None) -> NoReturn: ... + def setdefault(self, key: K, default: Optional[V] = None) -> NoReturn: ... + def update(self, *args: Any, **kwargs: V) -> NoReturn: ... + def pop(self, key: K, default: Optional[V] = None) -> NoReturn: ... # type: ignore def popitem(self) -> NoReturn: ... - def __setitem__(self, key, value) -> NoReturn: ... - def __delitem__(self, key) -> NoReturn: ... + def __setitem__(self, key: K, value: V) -> NoReturn: ... + def __delitem__(self, key: K) -> NoReturn: ... def clear(self) -> NoReturn: ... class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): def _iter_hashitems(self) -> Iterable[Hashable]: ... - def add(self, key, value) -> NoReturn: ... + def add(self, key: K, value: V) -> NoReturn: ... def popitemlist(self) -> NoReturn: ... - def poplist(self, key) -> NoReturn: ... - def setlist(self, key, new_list) -> NoReturn: ... - def setlistdefault(self, key, default_list=None) -> NoReturn: ... + def poplist(self, key: K) -> NoReturn: ... + def setlist(self, key: K, new_list: Iterable[V]) -> NoReturn: ... + def setlistdefault( + self, key: K, default_list: Optional[Iterable[V]] = None + ) -> NoReturn: ... def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ... @@ -129,7 +135,7 @@ class MultiDict(TypeConversionDict[K, V]): def values(self) -> Iterator[V]: ... # type: ignore def listvalues(self) -> Iterator[List[V]]: ... def copy(self) -> MultiDict[K, V]: ... - def deepcopy(self, memo=None) -> MultiDict[K, V]: ... + def deepcopy(self, memo: Any = None) -> MultiDict[K, V]: ... @overload def to_dict(self) -> Dict[K, V]: ... @overload @@ -145,7 +151,7 @@ class MultiDict(TypeConversionDict[K, V]): def poplist(self, key: K) -> List[V]: ... def popitemlist(self) -> Tuple[K, List[V]]: ... def __copy__(self) -> MultiDict[K, V]: ... - def __deepcopy__(self, memo) -> MultiDict[K, V]: ... + def __deepcopy__(self, memo: Any) -> MultiDict[K, V]: ... class _omd_bucket(Generic[K, V]): prev: Optional[_omd_bucket] @@ -273,20 +279,20 @@ class Headers(Dict[str, str]): def __copy__(self) -> Headers: ... class ImmutableHeadersMixin(Headers): - def __delitem__(self, key, _index_operation: bool = True) -> NoReturn: ... - def __setitem__(self, key, value) -> NoReturn: ... - def set(self, _key, _value, **kw) -> NoReturn: ... - def setlist(self, key, values) -> NoReturn: ... - def add(self, _key, _value, **kw) -> NoReturn: ... - def add_header(self, _key, _value, **_kw) -> NoReturn: ... - def remove(self, key) -> NoReturn: ... - def extend(self, *args, **kwargs) -> NoReturn: ... - def update(self, *args, **kwargs) -> NoReturn: ... - def insert(self, pos, value) -> NoReturn: ... - def pop(self, key=None, default=...) -> NoReturn: ... + def __delitem__(self, key: Any, _index_operation: bool = True) -> NoReturn: ... + def __setitem__(self, key: Any, value: Any) -> NoReturn: ... + def set(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... + def setlist(self, key: Any, values: Any) -> NoReturn: ... + def add(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... + def add_header(self, _key: Any, _value: Any, **_kw: Any) -> NoReturn: ... + def remove(self, key: Any) -> NoReturn: ... + def extend(self, *args: Any, **kwargs: Any) -> NoReturn: ... + def update(self, *args: Any, **kwargs: Any) -> NoReturn: ... + def insert(self, pos: Any, value: Any) -> NoReturn: ... + def pop(self, key: Any = None, default: Any = ...) -> NoReturn: ... def popitem(self) -> NoReturn: ... - def setdefault(self, key, default) -> NoReturn: ... # type: ignore - def setlistdefault(self, key, default) -> NoReturn: ... + def setdefault(self, key: Any, default: Any) -> NoReturn: ... # type: ignore + def setlistdefault(self, key: Any, default: Any) -> NoReturn: ... class EnvironHeaders(ImmutableHeadersMixin, Headers): environ: WSGIEnvironment @@ -298,11 +304,11 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore def copy(self) -> NoReturn: ... -class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): +class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore dicts: List[MultiDict[K, V]] def __init__(self, dicts: Optional[Iterable[MultiDict[K, V]]]) -> None: ... @classmethod - def fromkeys(cls, keys, value=None) -> NoReturn: ... + def fromkeys(cls, keys: Any, value: Any = None) -> NoReturn: ... def __getitem__(self, key: K) -> V: ... @overload # type: ignore def get(self, key: K) -> Optional[V]: ... @@ -344,11 +350,15 @@ class ImmutableDict(ImmutableDictMixin[K, V], Dict[K, V]): def copy(self) -> Dict[K, V]: ... def __copy__(self) -> ImmutableDict[K, V]: ... -class ImmutableMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): +class ImmutableMultiDict( # type: ignore + ImmutableMultiDictMixin[K, V], MultiDict[K, V] +): def copy(self) -> MultiDict[K, V]: ... def __copy__(self) -> ImmutableMultiDict[K, V]: ... -class ImmutableOrderedMultiDict(ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V]): +class ImmutableOrderedMultiDict( # type: ignore + ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V] +): def _iter_hashitems(self) -> Iterator[Tuple[int, Tuple[K, V]]]: ... def copy(self) -> OrderedMultiDict[K, V]: ... def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... @@ -356,9 +366,9 @@ class ImmutableOrderedMultiDict(ImmutableMultiDictMixin[K, V], OrderedMultiDict[ class Accept(ImmutableList[Tuple[str, int]]): provided: bool def __init__( - self, values: Optional[Union[Accept, Iterable[Tuple[str, int]]]] = None + self, values: Optional[Union[Accept, Iterable[Tuple[str, float]]]] = None ) -> None: ... - def _specificity(self, value) -> Tuple[bool, ...]: ... + def _specificity(self, value: str) -> Tuple[bool, ...]: ... def _value_matches(self, value: str, item: str) -> bool: ... @overload # type: ignore def __getitem__(self, key: str) -> int: ... @@ -382,7 +392,7 @@ class Accept(ImmutableList[Tuple[str, int]]): def _normalize_mime(value: str) -> List[str]: ... class MIMEAccept(Accept): - def _specificity(self, value) -> Tuple[bool, ...]: ... + def _specificity(self, value: str) -> Tuple[bool, ...]: ... def _value_matches(self, value: str, item: str) -> bool: ... @property def accept_html(self) -> bool: ... diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index 8f6258833..959f53013 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -19,8 +19,10 @@ from ..wrappers.request import Request from ..wrappers.response import Response from .console import Console +from .tbtools import Frame from .tbtools import get_current_traceback from .tbtools import render_console_html +from .tbtools import Traceback if t.TYPE_CHECKING: from wsgiref.types import StartResponse @@ -35,16 +37,16 @@ def hash_pin(pin: str) -> str: return hashlib.sha1(f"{pin} added salt".encode("utf-8", "replace")).hexdigest()[:12] -_machine_id: t.Optional[str] = None +_machine_id: t.Optional[t.Union[str, bytes]] = None -def get_machine_id() -> str: +def get_machine_id() -> t.Optional[t.Union[str, bytes]]: global _machine_id if _machine_id is not None: return _machine_id - def _generate(): + def _generate() -> t.Optional[t.Union[str, bytes]]: linux = b"" # machine-id is stable across boots, boot_id is not. @@ -100,15 +102,19 @@ def _generate(): 0, winreg.KEY_READ | winreg.KEY_WOW64_64KEY, ) as rk: + guid: t.Union[str, bytes] + guid_type: int guid, guid_type = winreg.QueryValueEx(rk, "MachineGuid") if guid_type == winreg.REG_SZ: - return guid.encode("utf-8") + return guid.encode("utf-8") # type: ignore return guid except OSError: pass + return None + _machine_id = _generate() return _machine_id @@ -254,8 +260,8 @@ def __init__( console_init_func = None self.app = app self.evalex = evalex - self.frames: t.Dict[t.Hashable, t.Any] = {} - self.tracebacks: t.Dict[t.Hashable, t.Any] = {} + self.frames: t.Dict[int, t.Union[Frame, _ConsoleFrame]] = {} + self.tracebacks: t.Dict[int, Traceback] = {} self.request_key = request_key self.console_path = console_path self.console_init_func = console_init_func @@ -344,7 +350,9 @@ def debug_application( traceback.log(environ["wsgi.errors"]) - def execute_command(self, request, command, frame): + def execute_command( + self, request: Request, command: str, frame: t.Union[Frame, _ConsoleFrame] + ) -> Response: """Execute a command in a console.""" return Response(frame.console.eval(command), mimetype="text/html") @@ -483,7 +491,7 @@ def __call__( and self.secret == secret and self.check_pin_trust(environ) ): - response = self.execute_command(request, cmd, frame) + response = self.execute_command(request, cmd, frame) # type: ignore elif ( self.evalex and self.console_path is not None diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index c2382c43a..da786603a 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -53,7 +53,7 @@ def _write(self, x: str) -> None: def write(self, x: str) -> None: self._write(escape(x)) - def writelines(self, x): + def writelines(self, x: t.Iterable[str]) -> None: self._write(escape("".join(x))) @@ -72,14 +72,14 @@ def fetch() -> str: stream = _local.stream except AttributeError: return "" - return stream.reset() + return stream.reset() # type: ignore @staticmethod def displayhook(obj: object) -> None: try: stream = _local.stream except AttributeError: - return _displayhook(obj) + return _displayhook(obj) # type: ignore # stream._write bypasses escaping as debug_repr is # already generating HTML for us. if obj is not None: diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index 75b4cd2c5..27c70ca42 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -582,7 +582,7 @@ def get_context_lines( @property def current_line(self) -> str: try: - return self.sourcelines[self.lineno - 1] + return self.sourcelines[self.lineno - 1] # type: ignore except IndexError: return "" diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index daceb2df3..3366f9da7 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -113,7 +113,9 @@ class newcls(cls, exception): # type: ignore _description = cls.description show_exception = False - def __init__(self, arg=None, *args, **kwargs): + def __init__( + self, arg: t.Optional[t.Any] = None, *args: t.Any, **kwargs: t.Any + ) -> None: super().__init__(*args, **kwargs) if arg is None: @@ -122,17 +124,17 @@ def __init__(self, arg=None, *args, **kwargs): exception.__init__(self, arg) @property - def description(self): + def description(self) -> str: if self.show_exception: return ( f"{self._description}\n" f"{exception.__name__}: {exception.__str__(self)}" ) - return self._description + return self._description # type: ignore @description.setter - def description(self, value): + def description(self, value: str) -> None: self._description = value newcls.__module__ = sys._getframe(1).f_globals["__name__"] @@ -245,7 +247,7 @@ class BadRequestKeyError(BadRequest, KeyError): #: useful in a debug mode. show_exception = False - def __init__(self, arg=None, *args, **kwargs): + def __init__(self, arg: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any): super().__init__(*args, **kwargs) if arg is None: @@ -264,7 +266,7 @@ def description(self) -> str: # type: ignore return self._description @description.setter - def description(self, value): + def description(self, value: str) -> None: self._description = value @@ -899,7 +901,9 @@ def __init__( if extra is not None: self.mapping.update(extra) - def __call__(self, code: t.Union[int, "Response"], *args, **kwargs) -> t.NoReturn: + def __call__( + self, code: t.Union[int, "Response"], *args: t.Any, **kwargs: t.Any + ) -> t.NoReturn: from .sansio.response import Response if isinstance(code, Response): @@ -911,7 +915,9 @@ def __call__(self, code: t.Union[int, "Response"], *args, **kwargs) -> t.NoRetur raise self.mapping[code](*args, **kwargs) -def abort(status: t.Union[int, "Response"], *args, **kwargs) -> t.NoReturn: +def abort( + status: t.Union[int, "Response"], *args: t.Any, **kwargs: t.Any +) -> t.NoReturn: """Raises an :py:exc:`HTTPException` for the given status code or WSGI application. diff --git a/src/werkzeug/filesystem.py b/src/werkzeug/filesystem.py index bdb8c83ea..36a3d12e9 100644 --- a/src/werkzeug/filesystem.py +++ b/src/werkzeug/filesystem.py @@ -1,5 +1,6 @@ import codecs import sys +import typing as t import warnings # We do not trust traditional unixes. @@ -8,7 +9,7 @@ ) -def _is_ascii_encoding(encoding: str) -> bool: +def _is_ascii_encoding(encoding: t.Optional[str]) -> bool: """Given an encoding this figures out if the encoding is actually ASCII (which is something we don't actually want in most cases). This is necessary because ASCII comes under many names such as ANSI_X3.4-1968. diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index d023c9d8f..54b699f78 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -39,7 +39,7 @@ class TStreamFactory(t.Protocol): def __call__( self, - total_content_length: int, + total_content_length: t.Optional[int], content_type: t.Optional[str], filename: t.Optional[str], content_length: t.Optional[int] = None, @@ -47,6 +47,9 @@ def __call__( ... +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) + + def _exhaust(stream: t.BinaryIO) -> None: bts = stream.read(64 * 1024) while bts: @@ -54,7 +57,7 @@ def _exhaust(stream: t.BinaryIO) -> None: def default_stream_factory( - total_content_length: int, + total_content_length: t.Optional[int], content_type: t.Optional[str], filename: t.Optional[str], content_length: t.Optional[int] = None, @@ -130,10 +133,10 @@ def parse_form_data( ).parse_from_environ(environ) -def exhaust_stream(f): +def exhaust_stream(f: F) -> F: """Helper decorator for methods that exhausts the stream on return.""" - def wrapper(self, stream, *args, **kwargs): + def wrapper(self, stream, *args, **kwargs): # type: ignore try: return f(self, stream, *args, **kwargs) finally: @@ -148,7 +151,7 @@ def wrapper(self, stream, *args, **kwargs): if not chunk: break - return update_wrapper(wrapper, f) + return update_wrapper(t.cast(F, wrapper), f) class FormDataParser: @@ -270,7 +273,7 @@ def _parse_multipart( self, stream: t.BinaryIO, mimetype: str, - content_length: int, + content_length: t.Optional[int], options: t.Dict[str, str], ) -> "t_parse_result": parser = MultiPartParser( @@ -293,7 +296,7 @@ def _parse_urlencoded( self, stream: t.BinaryIO, mimetype: str, - content_length: int, + content_length: t.Optional[int], options: t.Dict[str, str], ) -> "t_parse_result": if ( @@ -413,7 +416,7 @@ def get_part_charset(self, headers: Headers) -> str: return self.charset def start_file_streaming( - self, event: File, total_content_length: int + self, event: File, total_content_length: t.Optional[int] ) -> t.BinaryIO: content_type = event.headers.get("content-type") @@ -431,7 +434,7 @@ def start_file_streaming( return container def parse( - self, stream: t.BinaryIO, boundary: bytes, content_length: int + self, stream: t.BinaryIO, boundary: bytes, content_length: t.Optional[int] ) -> t.Tuple[MultiDict, MultiDict]: container: t.Union[t.BinaryIO, t.List[bytes]] _write: t.Callable[[bytes], t.Any] diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 5d94e0cf5..8880d553e 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -243,7 +243,7 @@ def unquote_header_value(value: str, is_filename: bool = False) -> str: def dump_options_header( - header: str, options: t.Dict[str, t.Optional[t.Union[str, int]]] + header: t.Optional[str], options: t.Mapping[str, t.Optional[t.Union[str, int]]] ) -> str: """The reverse function to :func:`parse_options_header`. @@ -390,7 +390,9 @@ def parse_options_header( ... -def parse_options_header(value, multiple=False): +def parse_options_header( + value: t.Optional[str], multiple: bool = False +) -> t.Union[t.Tuple[str, t.Dict[str, str]], t.Tuple[t.Any, ...]]: """Parse a ``Content-Type`` like header into a tuple with the content type and the options: @@ -414,7 +416,7 @@ def parse_options_header(value, multiple=False): if not value: return "", {} - result = [] + result: t.List[t.Any] = [] value = "," + value.replace("\n", ",") while value: @@ -422,10 +424,11 @@ def parse_options_header(value, multiple=False): if not match: break result.append(match.group(1)) # mimetype - options = {} + options: t.Dict[str, str] = {} # Parse options rest = match.group(2) - continued_encoding = None + encoding: t.Optional[str] + continued_encoding: t.Optional[str] = None while rest: optmatch = _option_header_piece_re.match(rest) if not optmatch: @@ -466,7 +469,7 @@ def parse_options_header(value, multiple=False): @typing.overload -def parse_accept_header(value: t.Optional[str], cls: None = None) -> "ds.Accept": +def parse_accept_header(value: t.Optional[str]) -> "ds.Accept": ... @@ -477,7 +480,9 @@ def parse_accept_header( ... -def parse_accept_header(value, cls=None): +def parse_accept_header( + value: t.Optional[str], cls: t.Optional[t.Type[_TAnyAccept]] = None +) -> _TAnyAccept: """Parses an HTTP Accept-* header. This does not implement a complete valid algorithm but one that supports at least value and quality extraction. @@ -494,7 +499,7 @@ def parse_accept_header(value, cls=None): :return: an instance of `cls`. """ if cls is None: - cls = ds.Accept + cls = t.cast(t.Type[_TAnyAccept], ds.Accept) if not value: return cls(None) @@ -528,7 +533,11 @@ def parse_cache_control_header( ... -def parse_cache_control_header(value, on_update=None, cls=None): +def parse_cache_control_header( + value: t.Optional[str], + on_update: _t_cc_update = None, + cls: t.Optional[t.Type[_TAnyCC]] = None, +) -> _TAnyCC: """Parse a cache control header. The RFC differs between response and request cache control, this method does not. It's your responsibility to not use the wrong control statements. @@ -546,9 +555,11 @@ def parse_cache_control_header(value, on_update=None, cls=None): :return: a `cls` object. """ if cls is None: - cls = ds.RequestCacheControl + cls = t.cast(t.Type[_TAnyCC], ds.RequestCacheControl) + if not value: - return cls(None, on_update) + return cls((), on_update) + return cls(parse_dict_header(value), on_update) @@ -570,7 +581,11 @@ def parse_csp_header( ... -def parse_csp_header(value, on_update=None, cls=None): +def parse_csp_header( + value: t.Optional[str], + on_update: _t_csp_update = None, + cls: t.Optional[t.Type[_TAnyCSP]] = None, +) -> _TAnyCSP: """Parse a Content Security Policy header. .. versionadded:: 1.0.0 @@ -584,16 +599,21 @@ def parse_csp_header(value, on_update=None, cls=None): :return: a `cls` object. """ if cls is None: - cls = ds.ContentSecurityPolicy + cls = t.cast(t.Type[_TAnyCSP], ds.ContentSecurityPolicy) + if value is None: - return cls(None, on_update) + return cls((), on_update) + items = [] + for policy in value.split(";"): policy = policy.strip() + # Ignore badly formatted policies (no space) if " " in policy: directive, value = policy.strip().split(" ", 1) items.append((directive.strip(), value.strip())) + return cls(items, on_update) @@ -1199,13 +1219,15 @@ def parse_cookie( if cls is None: cls = ds.MultiDict - def _parse_pairs(): - for key, val in _cookie_parse_impl(header): - key = _to_str(key, charset, errors, allow_none_charset=True) - if not key: + def _parse_pairs() -> t.Iterator[t.Tuple[str, str]]: + for key, val in _cookie_parse_impl(header): # type: ignore + key_str = _to_str(key, charset, errors, allow_none_charset=True) + + if not key_str: continue - val = _to_str(val, charset, errors, allow_none_charset=True) - yield key, val + + val_str = _to_str(val, charset, errors, allow_none_charset=True) + yield key_str, val_str return cls(_parse_pairs()) diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index a51669be7..4e664a011 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -10,7 +10,11 @@ from .wsgi import ClosingIterator if t.TYPE_CHECKING: + from wsgiref.types import StartResponse from wsgiref.types import WSGIApplication + from wsgiref.types import WSGIEnvironment + +F = t.TypeVar("F", bound=t.Callable[..., t.Any]) try: from greenlet import getcurrent as _get_ident @@ -18,7 +22,7 @@ from threading import get_ident as _get_ident -def get_ident(): +def get_ident() -> int: warnings.warn( "'get_ident' is deprecated and will be removed in Werkzeug" " 2.1. Use 'greenlet.getcurrent' or 'threading.get_ident' for" @@ -26,7 +30,7 @@ def get_ident(): DeprecationWarning, stacklevel=2, ) - return _get_ident() + return _get_ident() # type: ignore class _CannotUseContextVar(Exception): @@ -66,13 +70,13 @@ class ContextVar: # type: ignore of gevent. """ - def __init__(self, _name): - self.storage = {} + def __init__(self, _name: str) -> None: + self.storage: t.Dict[int, t.Dict[str, t.Any]] = {} - def get(self, default): + def get(self, default: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: return self.storage.get(_get_ident(), default) - def set(self, value): + def set(self, value: t.Dict[str, t.Any]) -> None: self.storage[_get_ident()] = value @@ -106,26 +110,26 @@ def __init__(self) -> None: object.__setattr__(self, "_storage", ContextVar("local_storage")) @property - def __storage__(self): + def __storage__(self) -> t.Dict[str, t.Any]: warnings.warn( "'__storage__' is deprecated and will be removed in Werkzeug 2.1.", DeprecationWarning, stacklevel=2, ) - return self._storage.get({}) + return self._storage.get({}) # type: ignore @property - def __ident_func__(self): + def __ident_func__(self) -> t.Callable[[], int]: warnings.warn( "'__ident_func__' is deprecated and will be removed in" " Werkzeug 2.1. It should not be used in Python 3.7+.", DeprecationWarning, stacklevel=2, ) - return _get_ident + return _get_ident # type: ignore @__ident_func__.setter - def __ident_func__(self, func): + def __ident_func__(self, func: t.Callable[[], int]) -> None: warnings.warn( "'__ident_func__' is deprecated and will be removed in" " Werkzeug 2.1. Setting it no longer has any effect.", @@ -206,7 +210,7 @@ def __ident_func__(self, value: t.Callable[[], int]) -> None: object.__setattr__(self._local, "__ident_func__", value) def __call__(self) -> "LocalProxy": - def _lookup(): + def _lookup() -> t.Any: rv = self.top if rv is None: raise RuntimeError("object unbound") @@ -219,7 +223,7 @@ def push(self, obj: t.Any) -> t.List[t.Any]: rv = getattr(self._local, "stack", []).copy() rv.append(obj) self._local.stack = rv - return rv + return rv # type: ignore def pop(self) -> t.Any: """Removes the topmost item from the stack, will return the @@ -285,16 +289,16 @@ def __init__( ) @property - def ident_func(self): + def ident_func(self) -> t.Callable[[], int]: warnings.warn( "'ident_func' is deprecated and will be removed in Werkzeug 2.1.", DeprecationWarning, stacklevel=2, ) - return _get_ident + return _get_ident # type: ignore @ident_func.setter - def ident_func(self, func): + def ident_func(self, func: t.Callable[[], int]) -> None: warnings.warn( "'ident_func' is deprecated and will be removedin Werkzeug" " 2.1. Setting it no longer has any effect.", @@ -323,7 +327,7 @@ def get_ident(self) -> int: ) return self.ident_func() - def cleanup(self): + def cleanup(self) -> None: """Manually clean up the data in the locals for this context. Call this at the end of the request or use `make_middleware()`. """ @@ -335,7 +339,9 @@ def make_middleware(self, app: "WSGIApplication") -> "WSGIApplication": request end. """ - def application(environ, start_response): + def application( + environ: "WSGIEnvironment", start_response: "StartResponse" + ) -> t.Iterable[bytes]: return ClosingIterator(app(environ, start_response), self.cleanup) return application @@ -374,18 +380,25 @@ class _ProxyLookup: __slots__ = ("bind_f", "fallback", "class_value", "name") - def __init__(self, f=None, fallback=None, class_value=None): + def __init__( + self, + f: t.Optional[t.Callable] = None, + fallback: t.Optional[t.Callable] = None, + class_value: t.Optional[t.Any] = None, + ) -> None: + bind_f: t.Optional[t.Callable[["LocalProxy", t.Any], t.Callable]] + if hasattr(f, "__get__"): # A Python function, can be turned into a bound method. - def bind_f(instance, obj): - return f.__get__(obj, type(obj)) + def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: + return f.__get__(obj, type(obj)) # type: ignore elif f is not None: # A C function, use partial to bind the first argument. - def bind_f(instance, obj): - return partial(f, obj) + def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: + return partial(f, obj) # type: ignore else: # Use getattr, which will produce a bound method. @@ -395,10 +408,10 @@ def bind_f(instance, obj): self.fallback = fallback self.class_value = class_value - def __set_name__(self, owner, name): + def __set_name__(self, owner: "LocalProxy", name: str) -> None: self.name = name - def __get__(self, instance, owner=None): + def __get__(self, instance: "LocalProxy", owner: t.Optional[type] = None) -> t.Any: if instance is None: if self.class_value is not None: return self.class_value @@ -411,17 +424,17 @@ def __get__(self, instance, owner=None): if self.fallback is None: raise - return self.fallback.__get__(instance, owner) + return self.fallback.__get__(instance, owner) # type: ignore if self.bind_f is not None: return self.bind_f(instance, obj) return getattr(obj, self.name) - def __repr__(self): + def __repr__(self) -> str: return f"proxy {self.name}" - def __call__(self, instance, *args, **kwargs): + def __call__(self, instance: "LocalProxy", *args: t.Any, **kwargs: t.Any) -> t.Any: """Support calling unbound methods from the class. For example, this happens with ``copy.copy``, which does ``type(x).__copy__(x)``. ``type(x)`` can't be proxied, so it @@ -437,26 +450,28 @@ class _ProxyIOp(_ProxyLookup): __slots__ = () - def __init__(self, f=None, fallback=None): + def __init__( + self, f: t.Optional[t.Callable] = None, fallback: t.Optional[t.Callable] = None + ) -> None: super().__init__(f, fallback) - def bind_f(instance, obj): - def i_op(self, other): - f(self, other) + def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: + def i_op(self: t.Any, other: t.Any) -> "LocalProxy": + f(self, other) # type: ignore return instance - return i_op.__get__(obj, type(obj)) + return i_op.__get__(obj, type(obj)) # type: ignore self.bind_f = bind_f -def _l_to_r_op(op): +def _l_to_r_op(op: F) -> F: """Swap the argument order to turn an l-op into an r-op.""" - def r_op(obj, other): + def r_op(obj: t.Any, other: t.Any) -> t.Any: return op(other, obj) - return r_op + return t.cast(F, r_op) class LocalProxy: @@ -537,24 +552,24 @@ def _get_current_object(self) -> t.Any: class_value=__doc__, fallback=lambda self: type(self).__doc__ ) # __del__ should only delete the proxy - __repr__ = _ProxyLookup( + __repr__ = _ProxyLookup( # type: ignore repr, fallback=lambda self: f"<{type(self).__name__} unbound>" ) - __str__ = _ProxyLookup(str) + __str__ = _ProxyLookup(str) # type: ignore __bytes__ = _ProxyLookup(bytes) __format__ = _ProxyLookup() # type: ignore __lt__ = _ProxyLookup(operator.lt) __le__ = _ProxyLookup(operator.le) - __eq__ = _ProxyLookup(operator.eq) - __ne__ = _ProxyLookup(operator.ne) + __eq__ = _ProxyLookup(operator.eq) # type: ignore + __ne__ = _ProxyLookup(operator.ne) # type: ignore __gt__ = _ProxyLookup(operator.gt) __ge__ = _ProxyLookup(operator.ge) __hash__ = _ProxyLookup(hash) # type: ignore __bool__ = _ProxyLookup(bool, fallback=lambda self: False) __getattr__ = _ProxyLookup(getattr) # __getattribute__ triggered through __getattr__ - __setattr__ = _ProxyLookup(setattr) - __delattr__ = _ProxyLookup(delattr) + __setattr__ = _ProxyLookup(setattr) # type: ignore + __delattr__ = _ProxyLookup(delattr) # type: ignore __dir__ = _ProxyLookup(dir, fallback=lambda self: []) # type: ignore # __get__ (proxying descriptor not supported) # __set__ (descriptor) diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 0724f014f..6e7df0ab9 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -48,7 +48,7 @@ class InputStream: def __init__(self, stream: t.BinaryIO) -> None: self._stream = stream - def read(self, *args): + def read(self, *args: t.Any) -> bytes: if len(args) == 0: warn( "WSGI does not guarantee an EOF marker on the input stream, thus making" @@ -65,7 +65,7 @@ def read(self, *args): ) return self._stream.read(*args) - def readline(self, *args): + def readline(self, *args: t.Any) -> bytes: if len(args) == 0: warn( "Calls to 'wsgi.input.readline()' without arguments are unsafe. Use" @@ -84,14 +84,14 @@ def readline(self, *args): raise TypeError("Too many arguments passed to 'wsgi.input.readline()'.") return self._stream.readline(*args) - def __iter__(self): + def __iter__(self) -> t.Iterator[bytes]: try: return iter(self._stream) except TypeError: warn("'wsgi.input' is not iterable.", WSGIWarning, stacklevel=2) return iter(()) - def close(self): + def close(self) -> None: warn("The application closed the input stream!", WSGIWarning, stacklevel=2) self._stream.close() @@ -100,18 +100,18 @@ class ErrorStream: def __init__(self, stream: t.TextIO) -> None: self._stream = stream - def write(self, s): + def write(self, s: str) -> None: check_type("wsgi.error.write()", s, str) self._stream.write(s) - def flush(self): + def flush(self) -> None: self._stream.flush() - def writelines(self, seq): + def writelines(self, seq: t.Iterable[str]) -> None: for line in seq: self.write(line) - def close(self): + def close(self) -> None: warn("The application closed the error stream!", WSGIWarning, stacklevel=2) self._stream.close() @@ -368,7 +368,7 @@ def check_iterator(self, app_iter: t.Iterable[bytes]) -> None: stacklevel=3, ) - def __call__(self, *args, **kwargs) -> t.Iterable[bytes]: + def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Iterable[bytes]: if len(args) != 2: warn("A WSGI app takes two arguments.", WSGIWarning, stacklevel=2) @@ -391,7 +391,9 @@ def __call__(self, *args, **kwargs) -> t.Iterable[bytes]: headers_set: t.List[t.Any] = [] chunks: t.List[int] = [] - def checking_start_response(*args, **kwargs) -> t.Callable[[bytes], None]: + def checking_start_response( + *args: t.Any, **kwargs: t.Any + ) -> t.Callable[[bytes], None]: if len(args) not in {2, 3}: warn( f"Invalid number of arguments: {len(args)}, expected 2 or 3.", diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index 6166e4563..a15e4b4f6 100644 --- a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -95,18 +95,18 @@ def __call__( ) -> t.Iterable[bytes]: response_body: t.List[bytes] = [] - def catching_start_response(status, headers, exc_info=None): + def catching_start_response(status, headers, exc_info=None): # type: ignore start_response(status, headers, exc_info) return response_body.append - def runapp(): + def runapp() -> None: app_iter = self._app( environ, t.cast("StartResponse", catching_start_response) ) response_body.extend(app_iter) if hasattr(app_iter, "close"): - app_iter.close() + app_iter.close() # type: ignore profile = Profile() start = time.time() diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index 2578f7ae4..9d172f7b7 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -28,7 +28,7 @@ from ..wsgi import wrap_file _TOpener = t.Callable[[], t.Tuple[t.BinaryIO, datetime, int]] -_TLoader = t.Callable[[t.Optional[str]], t.Tuple[str, _TOpener]] +_TLoader = t.Callable[[t.Optional[str]], t.Tuple[t.Optional[str], t.Optional[_TOpener]]] if t.TYPE_CHECKING: from wsgiref.types import StartResponse @@ -163,11 +163,17 @@ def get_package_loader(self, package: str, package_path: str) -> _TLoader: # Python 3 reader = provider.get_resource_reader(package) # type: ignore - def loader(path): + def loader( + path: t.Optional[str], + ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: if path is None: return None, None path = safe_join(package_path, path) + + if path is None: + return None, None + basename = posixpath.basename(path) try: @@ -198,11 +204,17 @@ def loader(path): is_filesystem = os.path.exists(package_filename) root = os.path.join(os.path.dirname(package_filename), package_path) - def loader(path): + def loader( + path: t.Optional[str], + ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: if path is None: return None, None path = safe_join(root, path) + + if path is None: + return None, None + basename = posixpath.basename(path) if is_filesystem: @@ -212,7 +224,7 @@ def loader(path): return basename, self._opener(path) try: - data = provider.get_data(path) + data = provider.get_data(path) # type: ignore except OSError: return None, None @@ -221,9 +233,14 @@ def loader(path): return loader def get_directory_loader(self, directory: str) -> _TLoader: - def loader(path): + def loader( + path: t.Optional[str], + ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: if path is not None: path = safe_join(directory, path) + + if path is None: + return None, None else: path = directory @@ -266,10 +283,10 @@ def __call__( if file_loader is not None: break - if file_loader is None or not self.is_allowed(real_filename): + if file_loader is None or not self.is_allowed(real_filename): # type: ignore return self.app(environ, start_response) - guessed_type = mimetypes.guess_type(real_filename) + guessed_type = mimetypes.guess_type(real_filename) # type: ignore mime_type = get_content_type(guessed_type[0] or self.fallback_mimetype, "utf-8") f, mtime, file_size = file_loader() @@ -277,7 +294,7 @@ def __call__( if self.cache: timeout = self.cache_timeout - etag = self.generate_etag(mtime, file_size, real_filename) + etag = self.generate_etag(mtime, file_size, real_filename) # type: ignore headers += [ ("Etag", f'"{etag}"'), ("Cache-Control", f"max-age={timeout}, public"), diff --git a/src/werkzeug/routing.py b/src/werkzeug/routing.py index a7825d219..cc5f0ad04 100644 --- a/src/werkzeug/routing.py +++ b/src/werkzeug/routing.py @@ -142,6 +142,7 @@ import typing_extensions as te from wsgiref.types import WSGIApplication from wsgiref.types import WSGIEnvironment + from .wrappers.response import Response _rule_re = re.compile( r""" @@ -182,7 +183,7 @@ def _pythonize(value: str) -> t.Union[None, bool, int, float, str]: return _PYTHON_CONSTANTS[value] for convert in int, float: try: - return convert(value) + return convert(value) # type: ignore except ValueError: pass if value[:1] == value[-1:] and value[0] in "\"'": @@ -262,7 +263,11 @@ def __init__(self, new_url: str) -> None: super().__init__(new_url) self.new_url = new_url - def get_response(self, environ=None): + def get_response( + self, + environ: t.Optional["WSGIEnvironment"] = None, + scope: t.Optional[dict] = None, + ) -> "Response": return redirect(self.new_url, self.code) @@ -478,7 +483,7 @@ class RuleTemplate: def __init__(self, rules: t.Iterable["Rule"]) -> None: self.rules = list(rules) - def __call__(self, *args, **kwargs) -> "RuleTemplateFactory": + def __call__(self, *args: t.Any, **kwargs: t.Any) -> "RuleTemplateFactory": return RuleTemplateFactory(self.rules, dict(*args, **kwargs)) @@ -801,7 +806,7 @@ def get_converter( raise LookupError(f"the converter {converter_name!r} does not exist") return self.map.converters[converter_name](self.map, *args, **kwargs) - def _encode_query_vars(self, query_vars: t.Mapping[str, t.Any]): + def _encode_query_vars(self, query_vars: t.Mapping[str, t.Any]) -> str: return url_encode( query_vars, charset=self.map.charset, @@ -863,7 +868,9 @@ def _build_regex(rule: str) -> None: if not self.is_leaf: self._trace.append((False, "/")) + self._build: t.Callable[..., t.Tuple[str, str]] self._build = self._compile_builder(False).__get__(self, None) # type: ignore + self._build_unknown: t.Callable[..., t.Tuple[str, str]] self._build_unknown = self._compile_builder(True).__get__( # type: ignore self, None ) @@ -882,7 +889,7 @@ def _build_regex(rule: str) -> None: def match( self, path: str, method: t.Optional[str] = None - ) -> t.Optional[t.Mapping[str, t.Any]]: + ) -> t.Optional[t.MutableMapping[str, t.Any]]: """Check if the rule matches a given path. Path is a string in the form ``"subdomain|/path"`` and is assembled by the map. If the map is doing host matching the subdomain part will be the host @@ -952,7 +959,7 @@ def _get_func_code(code: CodeType, name: str) -> t.Callable[..., t.Tuple[str, st globs: t.Dict[str, t.Any] = {} locs: t.Dict[str, t.Any] = {} exec(code, globs, locs) - return locs[name] + return locs[name] # type: ignore def _compile_builder( self, append_unknown: bool = True @@ -979,12 +986,12 @@ def _compile_builder( else: opl.append((True, data)) - def _convert(elem): + def _convert(elem: str) -> ast.stmt: ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem)) - ret.args = [ast.Name(str(elem), ast.Load())] # str for py2 + ret.args = [ast.Name(str(elem), ast.Load())] # type: ignore # str for py2 return ret - def _parts(ops): + def _parts(ops: t.List[t.Tuple[bool, str]]) -> t.List[ast.AST]: parts = [ _convert(elem) if is_dynamic else ast.Str(s=elem) for is_dynamic, elem in ops @@ -1007,7 +1014,7 @@ def _parts(ops): body = [_IF_KWARGS_URL_ENCODE_AST] url_parts.extend(_URL_ENCODE_AST_NAMES) - def _join(parts): + def _join(parts: t.List[ast.AST]) -> ast.AST: if len(parts) == 1: # shortcut return parts[0] return ast.JoinedStr(parts) @@ -1175,7 +1182,7 @@ class BaseConverter: regex = "[^/]+" weight = 100 - def __init__(self, map: "Map", *args, **kwargs) -> None: + def __init__(self, map: "Map", *args: t.Any, **kwargs: t.Any) -> None: self.map = map def to_python(self, value: str) -> t.Any: @@ -1347,8 +1354,14 @@ class FloatConverter(NumberConverter): regex = r"\d+\.\d+" num_convert = float - def __init__(self, map, min=None, max=None, signed=False): - super().__init__(map, min=min, max=max, signed=signed) + def __init__( + self, + map: "Map", + min: t.Optional[float] = None, + max: t.Optional[float] = None, + signed: bool = False, + ) -> None: + super().__init__(map, min=min, max=max, signed=signed) # type: ignore class UUIDConverter(BaseConverter): @@ -1369,7 +1382,7 @@ class UUIDConverter(BaseConverter): def to_python(self, value: str) -> uuid.UUID: return uuid.UUID(value) - def to_url(self, value): + def to_url(self, value: uuid.UUID) -> str: return str(value) @@ -1826,12 +1839,12 @@ def match( def match( self, - path_info=None, - method=None, - return_rule=False, - query_args=None, - websocket=None, - ): + path_info: t.Optional[str] = None, + method: t.Optional[str] = None, + return_rule: bool = False, + query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + websocket: t.Optional[bool] = None, + ) -> t.Tuple[t.Union[str, Rule], t.Mapping[str, t.Any]]: """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). The following things can then happen: @@ -1925,7 +1938,7 @@ def match( else: path_info = _to_str(path_info, self.map.charset) if query_args is None: - query_args = self.query_args + query_args = self.query_args or {} method = (method or self.default_method).upper() if websocket is None: @@ -1974,8 +1987,8 @@ def match( if rule.redirect_to is not None: if isinstance(rule.redirect_to, str): - def _handle_match(match): - value = rv[match.group(1)] + def _handle_match(match: t.Match[str]) -> str: + value = rv[match.group(1)] # type: ignore return rule._converters[match.group(1)].to_url(value) redirect_url = _simple_rule_re.sub(_handle_match, rule.redirect_to) diff --git a/src/werkzeug/sansio/response.py b/src/werkzeug/sansio/response.py index f5975ccf8..aedfcb043 100644 --- a/src/werkzeug/sansio/response.py +++ b/src/werkzeug/sansio/response.py @@ -7,6 +7,7 @@ from .._internal import _to_str from ..datastructures import Headers +from ..datastructures import HeaderSet from ..http import dump_cookie from ..http import HTTP_STATUS_CODES from ..utils import get_content_type @@ -35,8 +36,8 @@ def _set_property(name: str, doc: t.Optional[str] = None) -> property: - def fget(self): - def on_update(header_set): + def fget(self: "Response") -> HeaderSet: + def on_update(header_set: HeaderSet) -> None: if not header_set and name in self.headers: del self.headers[name] elif header_set: @@ -44,7 +45,12 @@ def on_update(header_set): return parse_set_header(self.headers.get(name), on_update) - def fset(self, value): + def fset( + self: "Response", + value: t.Optional[ + t.Union[str, t.Dict[str, t.Union[str, int]], t.Iterable[str]] + ], + ) -> None: if not value: del self.headers[name] elif isinstance(value, str): @@ -249,7 +255,7 @@ def delete_cookie( secure: bool = False, httponly: bool = False, samesite: t.Optional[str] = None, - ): + ) -> None: """Delete a cookie. Fails silently if key doesn't exist. :param key: the key (name) of the cookie to be deleted. @@ -311,7 +317,7 @@ def mimetype_params(self) -> t.Dict[str, str]: .. versionadded:: 0.5 """ - def on_update(d): + def on_update(d: t.Dict[str, str]) -> None: self.headers["Content-Type"] = dump_options_header(self.mimetype, d) d = parse_options_header(self.headers.get("content-type", ""))[1] @@ -482,7 +488,7 @@ def cache_control(self) -> ResponseCacheControl: request/response chain. """ - def on_update(cache_control): + def on_update(cache_control: ResponseCacheControl) -> None: if not cache_control and "cache-control" in self.headers: del self.headers["cache-control"] elif cache_control: diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index e37e9af1f..97cd01013 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -38,7 +38,7 @@ except ImportError: class _SslDummy: - def __getattr__(self, name): + def __getattr__(self, name: str) -> t.Any: raise RuntimeError("SSL support unavailable") ssl = _SslDummy() # type: ignore @@ -159,7 +159,7 @@ def server_version(self) -> str: # type: ignore def make_environ(self) -> "WSGIEnvironment": request_url = url_parse(self.path) - def shutdown_server(): + def shutdown_server() -> None: warnings.warn( "The 'environ['werkzeug.server.shutdown']' function is" " deprecated and will be removed in Werkzeug 2.1.", @@ -290,7 +290,7 @@ def write(data: bytes) -> None: self.wfile.write(data) self.wfile.flush() - def start_response(status, headers, exc_info=None): + def start_response(status, headers, exc_info=None): # type: ignore nonlocal status_set, headers_set if exc_info: try: @@ -387,7 +387,7 @@ def version_string(self) -> str: def address_string(self) -> str: if getattr(self, "environ", None): - return self.environ["REMOTE_ADDR"] + return self.environ["REMOTE_ADDR"] # type: ignore if not self.client_address: return "" @@ -397,7 +397,9 @@ def address_string(self) -> str: def port_integer(self) -> int: return self.client_address[1] - def log_request(self, code: t.Union[int, str] = "-", size: t.Union[int, str] = "-"): + def log_request( + self, code: t.Union[int, str] = "-", size: t.Union[int, str] = "-" + ) -> None: try: path = uri_to_iri(self.path) msg = f"{self.command} {path} {self.request_version}" @@ -425,13 +427,13 @@ def log_request(self, code: t.Union[int, str] = "-", size: t.Union[int, str] = " self.log("info", '"%s" %s %s', msg, code, size) - def log_error(self, *args) -> None: - self.log("error", *args) + def log_error(self, format: str, *args: t.Any) -> None: + self.log("error", format, *args) - def log_message(self, format: str, *args) -> None: + def log_message(self, format: str, *args: t.Any) -> None: self.log("info", format, *args) - def log(self, type: str, message: str, *args) -> None: + def log(self, type: str, message: str, *args: t.Any) -> None: _log( type, f"{self.address_string()} - - [{self.log_date_time_string()}] {message}\n", @@ -439,7 +441,7 @@ def log(self, type: str, message: str, *args) -> None: ) -def _ansi_style(value, *styles): +def _ansi_style(value: str, *styles: str) -> str: codes = { "bold": 1, "red": 31, @@ -466,8 +468,10 @@ def generate_adhoc_ssl_pair( from cryptography.hazmat.primitives.asymmetric import rsa except ImportError: raise TypeError("Using ad-hoc certificates requires the cryptography library.") + + backend = default_backend() pkey = rsa.generate_private_key( - public_exponent=65537, key_size=2048, backend=default_backend() + public_exponent=65537, key_size=2048, backend=backend ) # pretty damn sure that this is not actually accepted by anyone @@ -481,6 +485,7 @@ def generate_adhoc_ssl_pair( ] ) + backend = default_backend() cert = ( x509.CertificateBuilder() .subject_name(subject) @@ -491,7 +496,7 @@ def generate_adhoc_ssl_pair( .not_valid_after(dt.now(timezone.utc) + timedelta(days=365)) .add_extension(x509.ExtendedKeyUsage([x509.OID_SERVER_AUTH]), critical=False) .add_extension(x509.SubjectAlternativeName([x509.DNSName("*")]), critical=False) - .sign(pkey, hashes.SHA256(), default_backend()) + .sign(pkey, hashes.SHA256(), backend) ) return cert, pkey @@ -591,10 +596,10 @@ def load_ssl_context( return ctx -def is_ssl_error(error=None): +def is_ssl_error(error: t.Optional[Exception] = None) -> bool: """Checks if the given error (or the current one) is an SSL error.""" if error is None: - error = sys.exc_info()[1] + error = t.cast(Exception, sys.exc_info()[1]) return isinstance(error, ssl.SSLError) @@ -624,7 +629,7 @@ def get_sockaddr( return res[0][4] # type: ignore -def get_interface_ip(family: socket.AddressFamily): +def get_interface_ip(family: socket.AddressFamily) -> str: """Get the IP address of an external interface. Used when binding to 0.0.0.0 or ::1 to show a more useful URL. @@ -639,7 +644,7 @@ def get_interface_ip(family: socket.AddressFamily): except OSError: return "::1" if family == socket.AF_INET6 else "127.0.0.1" - return s.getsockname()[0] + return s.getsockname()[0] # type: ignore class BaseWSGIServer(HTTPServer): @@ -703,10 +708,10 @@ def __init__( else: self.ssl_context = None - def log(self, type: str, message: str, *args) -> None: + def log(self, type: str, message: str, *args: t.Any) -> None: _log(type, message, *args) - def serve_forever(self, poll_interval=0.5) -> None: + def serve_forever(self, poll_interval: float = 0.5) -> None: self.shutdown_signal = False try: super().serve_forever(poll_interval=poll_interval) @@ -905,7 +910,7 @@ def run_simple( application = SharedDataMiddleware(application, static_files) - def log_startup(sock): + def log_startup(sock: socket.socket) -> None: all_addresses_message = ( " * Running on all addresses.\n" " WARNING: This is a development server. Do not use it in" @@ -935,9 +940,9 @@ def log_startup(sock): sock.getsockname()[1], ) - def inner(): + def inner() -> None: try: - fd = int(os.environ["WERKZEUG_SERVER_FD"]) + fd: t.Optional[int] = int(os.environ["WERKZEUG_SERVER_FD"]) except (LookupError, ValueError): fd = None srv = make_server( @@ -1003,7 +1008,7 @@ def inner(): inner() -def run_with_reloader(*args, **kwargs) -> None: +def run_with_reloader(*args: t.Any, **kwargs: t.Any) -> None: """Run a process with the reloader. This is not a public API, do not use this function. diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index 9e02e3b35..8815b0f79 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -64,30 +64,31 @@ def stream_encode_multipart( if boundary is None: boundary = f"---------------WerkzeugFormPart_{time()}{random()}" - stream = BytesIO() + stream: t.BinaryIO = BytesIO() total_length = 0 on_disk = False if use_tempfile: - def write_binary(string): + def write_binary(s: bytes) -> int: nonlocal stream, total_length, on_disk if on_disk: - stream.write(string) + return stream.write(s) else: - length = len(string) + length = len(s) if length + total_length <= threshold: - stream.write(string) + stream.write(s) else: - new_stream = TemporaryFile("wb+") - new_stream.write(stream.getvalue()) - new_stream.write(string) + new_stream = t.cast(t.BinaryIO, TemporaryFile("wb+")) + new_stream.write(stream.getvalue()) # type: ignore + new_stream.write(s) stream = new_stream on_disk = True total_length += length + return length else: write_binary = stream.write @@ -451,7 +452,9 @@ def __init__( self.mimetype = mimetype @classmethod - def from_environ(cls, environ: "WSGIEnvironment", **kwargs) -> "EnvironBuilder": + def from_environ( + cls, environ: "WSGIEnvironment", **kwargs: t.Any + ) -> "EnvironBuilder": """Turn an environ dict back into a builder. Any extra kwargs override the args extracted from the environ. @@ -565,7 +568,7 @@ def mimetype_params(self) -> t.Mapping[str, str]: .. versionadded:: 0.14 """ - def on_update(d): + def on_update(d: t.Mapping[str, str]) -> None: self.headers["Content-Type"] = dump_options_header(self.mimetype, d) d = parse_options_header(self.headers.get("content-type", ""))[1] @@ -602,7 +605,7 @@ def _get_form(self, name: str, storage: t.Type[_TAnyMultiDict]) -> _TAnyMultiDic rv = storage() setattr(self, name, rv) - return rv + return rv # type: ignore def _set_form(self, name: str, value: MultiDict) -> None: """Common behavior for setting the :attr:`form` and @@ -1007,11 +1010,11 @@ def resolve_redirect( def open( self, - *args, + *args: t.Any, as_tuple: bool = False, buffered: bool = False, follow_redirects: bool = False, - **kwargs, + **kwargs: t.Any, ) -> "TestResponse": """Generate an environ dict from the given arguments, make a request to the application using it, and return the response. @@ -1118,42 +1121,42 @@ def open( return response - def get(self, *args, **kw) -> "TestResponse": + def get(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``GET``.""" kw["method"] = "GET" return self.open(*args, **kw) - def post(self, *args, **kw) -> "TestResponse": + def post(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``POST``.""" kw["method"] = "POST" return self.open(*args, **kw) - def put(self, *args, **kw) -> "TestResponse": + def put(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``PUT``.""" kw["method"] = "PUT" return self.open(*args, **kw) - def delete(self, *args, **kw) -> "TestResponse": + def delete(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``DELETE``.""" kw["method"] = "DELETE" return self.open(*args, **kw) - def patch(self, *args, **kw) -> "TestResponse": + def patch(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``PATCH``.""" kw["method"] = "PATCH" return self.open(*args, **kw) - def options(self, *args, **kw) -> "TestResponse": + def options(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``OPTIONS``.""" kw["method"] = "OPTIONS" return self.open(*args, **kw) - def head(self, *args, **kw) -> "TestResponse": + def head(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``HEAD``.""" kw["method"] = "HEAD" return self.open(*args, **kw) - def trace(self, *args, **kw) -> "TestResponse": + def trace(self, *args: t.Any, **kw: t.Any) -> "TestResponse": """Call :meth:`open` with ``method`` set to ``TRACE``.""" kw["method"] = "TRACE" return self.open(*args, **kw) @@ -1162,7 +1165,7 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self.application!r}>" -def create_environ(*args, **kwargs) -> "WSGIEnvironment": +def create_environ(*args: t.Any, **kwargs: t.Any) -> "WSGIEnvironment": """Create a new WSGI environ dict based on the values passed. The first parameter should be the path of the request which defaults to '/'. The second one can either be an absolute path (in that case the host is @@ -1211,7 +1214,7 @@ def run_wsgi_app( response: t.Optional[t.Tuple[str, t.List[t.Tuple[str, str]]]] = None buffer: t.List[bytes] = [] - def start_response(status, headers, exc_info=None): + def start_response(status, headers, exc_info=None): # type: ignore nonlocal response if exc_info: @@ -1287,7 +1290,7 @@ def __init__( headers: Headers, request: Request, history: t.Tuple["TestResponse"] = (), # type: ignore - **kwargs, + **kwargs: t.Any, ) -> None: super().__init__(response, status, headers, **kwargs) self.request = request diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 0f93da584..76f83974b 100644 --- a/src/werkzeug/testapp.py +++ b/src/werkzeug/testapp.py @@ -141,7 +141,7 @@ def iter_sys_path() -> t.Iterator[t.Tuple[str, bool, bool]]: if os.name == "posix": - def strip(x): + def strip(x: str) -> str: prefix = os.path.expanduser("~") if x.startswith(prefix): x = f"~{x[len(prefix) :]}" @@ -149,7 +149,7 @@ def strip(x): else: - def strip(x): + def strip(x: str) -> str: return x cwd = os.path.abspath(os.getcwd()) @@ -164,7 +164,10 @@ def render_testapp(req: Request) -> bytes: except ImportError: eggs: t.Iterable[t.Any] = () else: - eggs = sorted(pkg_resources.working_set, key=lambda x: x.project_name.lower()) + eggs = sorted( + pkg_resources.working_set, + key=lambda x: x.project_name.lower(), # type: ignore + ) python_eggs = [] for egg in eggs: try: diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index d9edb8860..7566ac273 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -60,7 +60,7 @@ class BaseURL(_URLTuple): def __str__(self) -> str: return self.to_url() - def replace(self, **kwargs) -> "BaseURL": + def replace(self, **kwargs: t.Any) -> "BaseURL": """Return an URL with the same values, except for those parameters given new values by whichever keyword arguments are specified.""" return self._replace(**kwargs) @@ -142,14 +142,14 @@ def raw_password(self) -> t.Optional[str]: """ return self._split_auth()[1] - def decode_query(self, *args, **kwargs) -> "ds.MultiDict[str, str]": + def decode_query(self, *args: t.Any, **kwargs: t.Any) -> "ds.MultiDict[str, str]": """Decodes the query part of the URL. Ths is a shortcut for calling :func:`url_decode` on the query argument. The arguments and keyword arguments are forwarded to :func:`url_decode` unchanged. """ return url_decode(self.query, *args, **kwargs) - def join(self, *args, **kwargs) -> "BaseURL": + def join(self, *args: t.Any, **kwargs: t.Any) -> "BaseURL": """Joins this URL with another one. This is just a convenience function for calling into :meth:`url_join` and then parsing the return value again. @@ -339,7 +339,7 @@ class URL(BaseURL): _lbracket = "[" _rbracket = "]" - def encode(self, charset="utf-8", errors="replace") -> "BytesURL": + def encode(self, charset: str = "utf-8", errors: str = "replace") -> "BytesURL": """Encodes the URL to a tuple made out of bytes. The charset is only being used for the path, query and fragment. """ @@ -368,7 +368,7 @@ def encode_netloc(self) -> bytes: # type: ignore """Returns the netloc unchanged as bytes.""" return self.netloc # type: ignore - def decode(self, charset="utf-8", errors="replace") -> "URL": + def decode(self, charset: str = "utf-8", errors: str = "replace") -> "URL": """Decodes the URL to a tuple made out of strings. The charset is only being used for the path, query and fragment. """ @@ -422,7 +422,7 @@ def _url_encode_impl( charset: str, sort: bool, key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]], -): +) -> t.Iterator[str]: from .datastructures import iter_multi_items iterable: t.Iterable[t.Tuple[str, str]] = iter_multi_items(obj) @@ -699,14 +699,14 @@ def url_fix(s: str, charset: str = "utf-8") -> str: _to_iri_unsafe = "".join([chr(c) for c in range(128) if c not in _always_safe]) -def _codec_error_url_quote(e): +def _codec_error_url_quote(e: UnicodeError) -> t.Tuple[str, int]: """Used in :func:`uri_to_iri` after unquoting to re-quote any invalid bytes. """ - # the docs state that `UnicodeError` does have these attributes, - # but mypy isn't picking them up? - out = _fast_url_quote(e.object[e.start : e.end]) - return out, e.end + # the docs state that UnicodeError does have these attributes, + # but mypy isn't picking them up + out = _fast_url_quote(e.object[e.start : e.end]) # type: ignore + return out, e.end # type: ignore codecs.register_error("werkzeug.url_quote", _codec_error_url_quote) @@ -870,7 +870,7 @@ def url_decode( def url_decode_stream( stream: t.BinaryIO, - charset="utf-8", + charset: str = "utf-8", decode_keys: None = None, include_empty: bool = True, errors: str = "replace", @@ -1026,7 +1026,7 @@ def url_encode_stream( separator = _to_str(separator, "ascii") gen = _url_encode_impl(obj, charset, sort, key) if stream is None: - return gen + return gen # type: ignore for idx, chunk in enumerate(gen): if idx: stream.write(separator) @@ -1038,7 +1038,7 @@ def url_join( base: t.Union[str, t.Tuple[str, str, str, str, str]], url: t.Union[str, t.Tuple[str, str, str, str, str]], allow_fragments: bool = True, -): +) -> str: """Join a base URL and a possibly relative URL to form an absolute interpretation of the latter. @@ -1160,7 +1160,9 @@ class Href: `sort` and `key` were added. """ - def __init__(self, base="./", charset="utf-8", sort=False, key=None): + def __init__( # type: ignore + self, base="./", charset="utf-8", sort=False, key=None + ): warnings.warn( "'Href' is deprecated and will be removed in Werkzeug 2.1." " Use 'werkzeug.routing' instead.", @@ -1175,7 +1177,7 @@ def __init__(self, base="./", charset="utf-8", sort=False, key=None): self.sort = sort self.key = key - def __getattr__(self, name): + def __getattr__(self, name): # type: ignore if name[:2] == "__": raise AttributeError(name) base = self.base @@ -1183,7 +1185,7 @@ def __getattr__(self, name): base += "/" return Href(url_join(base, name), self.charset, self.sort, self.key) - def __call__(self, *path, **query): + def __call__(self, *path, **query): # type: ignore if path and isinstance(path[-1], dict): if query: raise TypeError("keyword arguments and query-dicts can't be combined") diff --git a/src/werkzeug/user_agent.py b/src/werkzeug/user_agent.py index b1367553c..66ffcbe07 100644 --- a/src/werkzeug/user_agent.py +++ b/src/werkzeug/user_agent.py @@ -33,7 +33,7 @@ def __init__(self, string: str) -> None: self.string: str = string """The original header value.""" - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {self.browser}/{self.version}>" def __str__(self) -> str: diff --git a/src/werkzeug/useragents.py b/src/werkzeug/useragents.py index 9930f72be..9ff940312 100644 --- a/src/werkzeug/useragents.py +++ b/src/werkzeug/useragents.py @@ -120,7 +120,7 @@ class UserAgentParser(_UserAgentParser): instead. """ - def __init__(self): + def __init__(self) -> None: warnings.warn( "'UserAgentParser' is deprecated and will be removed in" " Werkzeug 2.1. Use a dedicated parser library instead.", @@ -131,7 +131,7 @@ def __init__(self): class _deprecated_property(property): - def __init__(self, fget): + def __init__(self, fget: t.Callable[["_UserAgent"], t.Any]) -> None: super().__init__(fget) self.message = ( "The built-in user agent parser is deprecated and will be" @@ -141,7 +141,7 @@ def __init__(self, fget): " parser." ) - def __get__(self, *args, **kwargs): + def __get__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: warnings.warn(self.message, DeprecationWarning, stacklevel=3) return super().__get__(*args, **kwargs) diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index cad786c76..80e2502ce 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -27,7 +27,8 @@ if t.TYPE_CHECKING: from wsgiref.types import WSGIEnvironment - from .wrappers import Response + from .wrappers.request import Request + from .wrappers.response import Response _entity_re = re.compile(r"&([^;]+);") _filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") @@ -158,14 +159,14 @@ class environ_property(_DictAccessorProperty[_TAccessorValue]): read_only = True - def lookup(self, obj: t.Any) -> "WSGIEnvironment": + def lookup(self, obj: "Request") -> "WSGIEnvironment": return obj.environ class header_property(_DictAccessorProperty[_TAccessorValue]): """Like `environ_property` but for headers.""" - def lookup(self, obj: t.Any) -> Headers: + def lookup(self, obj: t.Union["Request", "Response"]) -> Headers: return obj.headers @@ -238,10 +239,10 @@ class HTMLBuilder: _plaintext_elements = {"textarea"} _c_like_cdata = {"script", "style"} - def __init__(self, dialect): + def __init__(self, dialect): # type: ignore self._dialect = dialect - def __call__(self, s): + def __call__(self, s): # type: ignore import html warnings.warn( @@ -251,7 +252,7 @@ def __call__(self, s): ) return html.escape(s) - def __getattr__(self, tag): + def __getattr__(self, tag): # type: ignore import html warnings.warn( @@ -262,7 +263,7 @@ def __getattr__(self, tag): if tag[:2] == "__": raise AttributeError(tag) - def proxy(*children, **arguments): + def proxy(*children, **arguments): # type: ignore buffer = f"<{tag}" for key, value in arguments.items(): if value is None: @@ -299,7 +300,7 @@ def proxy(*children, **arguments): return proxy - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} for {self._dialect!r}>" @@ -401,7 +402,7 @@ def detect_utf_encoding(data: bytes) -> str: return "utf-8" -def format_string(string, context): +def format_string(string: str, context: t.Mapping[str, t.Any]) -> str: """String-template format a string: >>> format_string('$foo and ${foo}s', dict(foo=42)) @@ -474,7 +475,7 @@ def secure_filename(filename: str) -> str: return filename -def escape(s): +def escape(s: t.Any) -> str: """Replace ``&``, ``<``, ``>``, ``"``, and ``'`` with HTML-safe sequences. @@ -496,15 +497,15 @@ def escape(s): return "" if hasattr(s, "__html__"): - return s.__html__() + return s.__html__() # type: ignore if not isinstance(s, str): s = str(s) - return html.escape(s, quote=True) + return html.escape(s, quote=True) # type: ignore -def unescape(s): +def unescape(s: str) -> str: """The reverse of :func:`escape`. This unescapes all the HTML entities, not only those inserted by ``escape``. @@ -600,7 +601,7 @@ def send_file( use_x_sendfile: bool = False, response_class: t.Optional[t.Type["Response"]] = None, _root_path: t.Optional[t.Union[os.PathLike, str]] = None, -): +) -> "Response": """Send the contents of a file to the client. The first argument can be a file path or a file-like object. Paths @@ -803,7 +804,7 @@ def send_from_directory( directory: t.Union[os.PathLike, str], path: t.Union[os.PathLike, str], environ: "WSGIEnvironment", - **kwargs, + **kwargs: t.Any, ) -> "Response": """Send a file from within a directory using :func:`send_file`. @@ -914,7 +915,7 @@ def find_modules( yield modname -def validate_arguments(func, args, kwargs, drop_extra=True): +def validate_arguments(func, args, kwargs, drop_extra=True): # type: ignore """Checks if the function accepts the arguments and keyword arguments. Returns a new ``(args, kwargs)`` tuple that can safely be passed to the function without causing a `TypeError` because the function signature @@ -977,7 +978,7 @@ def proxy(request): return tuple(args), kwargs -def bind_arguments(func, args, kwargs): +def bind_arguments(func, args, kwargs): # type: ignore """Bind the arguments provided into a dict. When passed a function, a tuple of arguments and a dict of keyword arguments `bind_arguments` returns a dict of names as the function would see it. This can be useful @@ -1036,7 +1037,7 @@ class ArgumentValidationError(ValueError): ``validate_arguments``. """ - def __init__(self, missing=None, extra=None, extra_positional=None): + def __init__(self, missing=None, extra=None, extra_positional=None): # type: ignore self.missing = set(missing or ()) self.extra = extra or {} self.extra_positional = extra_positional or [] @@ -1055,7 +1056,7 @@ class ImportStringError(ImportError): #: Wrapped exception. exception: BaseException - def __init__(self, import_name, exception): + def __init__(self, import_name: str, exception: BaseException) -> None: self.import_name = import_name self.exception = exception msg = import_name @@ -1085,5 +1086,5 @@ def __init__(self, import_name, exception): super().__init__(msg) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__}({self.import_name!r}, {self.exception!r})>" diff --git a/src/werkzeug/wrappers/accept.py b/src/werkzeug/wrappers/accept.py index 6de5294a8..9605e637d 100644 --- a/src/werkzeug/wrappers/accept.py +++ b/src/werkzeug/wrappers/accept.py @@ -1,8 +1,9 @@ +import typing as t import warnings class AcceptMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'AcceptMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Request' now includes the functionality" @@ -10,4 +11,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/auth.py b/src/werkzeug/wrappers/auth.py index 11514e8db..da31b7cf7 100644 --- a/src/werkzeug/wrappers/auth.py +++ b/src/werkzeug/wrappers/auth.py @@ -1,8 +1,9 @@ +import typing as t import warnings class AuthorizationMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'AuthorizationMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Request' now includes the functionality" @@ -10,11 +11,11 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore class WWWAuthenticateMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'WWWAuthenticateMixin' is deprecated and will be removed" " in Werkzeug 2.1. 'Response' now includes the" @@ -22,4 +23,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/base_request.py b/src/werkzeug/wrappers/base_request.py index 48dcc46c5..451989fd7 100644 --- a/src/werkzeug/wrappers/base_request.py +++ b/src/werkzeug/wrappers/base_request.py @@ -1,10 +1,11 @@ +import typing as t import warnings from .request import Request class _FakeSubclassCheck(type): - def __subclasscheck__(cls, subclass): + def __subclasscheck__(cls, subclass: t.Type) -> bool: warnings.warn( "'BaseRequest' is deprecated and will be removed in" " Werkzeug 2.1. Use 'issubclass(cls, Request)' instead.", @@ -13,7 +14,7 @@ def __subclasscheck__(cls, subclass): ) return issubclass(subclass, Request) - def __instancecheck__(cls, instance): + def __instancecheck__(cls, instance: t.Any) -> bool: warnings.warn( "'BaseRequest' is deprecated and will be removed in" " Werkzeug 2.1. Use 'isinstance(obj, Request)' instead.", @@ -24,7 +25,7 @@ def __instancecheck__(cls, instance): class BaseRequest(Request, metaclass=_FakeSubclassCheck): - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'BaseRequest' is deprecated and will be removed in" " Werkzeug 2.1. 'Request' now includes the functionality" diff --git a/src/werkzeug/wrappers/base_response.py b/src/werkzeug/wrappers/base_response.py index a102d8943..3e0dc6766 100644 --- a/src/werkzeug/wrappers/base_response.py +++ b/src/werkzeug/wrappers/base_response.py @@ -1,10 +1,11 @@ +import typing as t import warnings from .response import Response class _FakeSubclassCheck(type): - def __subclasscheck__(cls, subclass): + def __subclasscheck__(cls, subclass: t.Type) -> bool: warnings.warn( "'BaseResponse' is deprecated and will be removed in" " Werkzeug 2.1. Use 'issubclass(cls, Response)' instead.", @@ -13,7 +14,7 @@ def __subclasscheck__(cls, subclass): ) return issubclass(subclass, Response) - def __instancecheck__(cls, instance): + def __instancecheck__(cls, instance: t.Any) -> bool: warnings.warn( "'BaseResponse' is deprecated and will be removed in" " Werkzeug 2.1. Use 'isinstance(obj, Response)' instead.", @@ -24,7 +25,7 @@ def __instancecheck__(cls, instance): class BaseResponse(Response, metaclass=_FakeSubclassCheck): - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'BaseResponse' is deprecated and will be removed in" " Werkzeug 2.1. 'Response' now includes the functionality" diff --git a/src/werkzeug/wrappers/common_descriptors.py b/src/werkzeug/wrappers/common_descriptors.py index 77e935cc3..db87ea5fa 100644 --- a/src/werkzeug/wrappers/common_descriptors.py +++ b/src/werkzeug/wrappers/common_descriptors.py @@ -1,8 +1,9 @@ +import typing as t import warnings class CommonRequestDescriptorsMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'CommonRequestDescriptorsMixin' is deprecated and will be" " removed in Werkzeug 2.1. 'Request' now includes the" @@ -10,11 +11,11 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore class CommonResponseDescriptorsMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'CommonResponseDescriptorsMixin' is deprecated and will be" " removed in Werkzeug 2.1. 'Response' now includes the" @@ -22,4 +23,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/cors.py b/src/werkzeug/wrappers/cors.py index 3039abecb..89cf83ef8 100644 --- a/src/werkzeug/wrappers/cors.py +++ b/src/werkzeug/wrappers/cors.py @@ -1,8 +1,9 @@ +import typing as t import warnings class CORSRequestMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'CORSRequestMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Request' now includes the functionality" @@ -10,11 +11,11 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore class CORSResponseMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'CORSResponseMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Response' now includes the functionality" @@ -22,4 +23,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/etag.py b/src/werkzeug/wrappers/etag.py index 8b9f87875..2e9015a58 100644 --- a/src/werkzeug/wrappers/etag.py +++ b/src/werkzeug/wrappers/etag.py @@ -1,8 +1,9 @@ +import typing as t import warnings class ETagRequestMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'ETagRequestMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Request' now includes the functionality" @@ -10,11 +11,11 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore class ETagResponseMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'ETagResponseMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Response' now includes the functionality" @@ -22,4 +23,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/json.py b/src/werkzeug/wrappers/json.py index a9ff5295b..ab6ed7ba9 100644 --- a/src/werkzeug/wrappers/json.py +++ b/src/werkzeug/wrappers/json.py @@ -1,12 +1,13 @@ +import typing as t import warnings class JSONMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'JSONMixin' is deprecated and will be removed in Werkzeug" " 2.1. 'Request' now includes the functionality directly.", DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index 255913772..ecbd02017 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -1,5 +1,6 @@ import functools import json +import typing import typing as t import warnings from io import BytesIO @@ -21,6 +22,7 @@ from werkzeug.exceptions import BadRequest if t.TYPE_CHECKING: + import typing_extensions as te from wsgiref.types import WSGIApplication from wsgiref.types import WSGIEnvironment @@ -140,7 +142,7 @@ def __init__( self.environ["werkzeug.request"] = self @classmethod - def from_values(cls, *args, **kwargs) -> "Request": + def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request": """Create a new request object based on the values provided. If environ is given missing values are filled from there. This method is useful for small scripts when you need to simulate a request from an URL. @@ -197,7 +199,7 @@ def my_wsgi_app(request): from ..exceptions import HTTPException @functools.wraps(f) - def application(*args): + def application(*args): # type: ignore request = cls(args[-2]) with request: try: @@ -206,15 +208,15 @@ def application(*args): resp = e.get_response(args[-2]) return resp(*args[-2:]) - return application + return t.cast("WSGIApplication", application) def _get_file_stream( self, - total_content_length: int, + total_content_length: t.Optional[int], content_type: t.Optional[str], filename: t.Optional[str] = None, content_length: t.Optional[int] = None, - ): + ) -> t.BinaryIO: """Called to get a stream for the file upload. This must provide a file-like class with `read()`, `readline()` @@ -308,7 +310,7 @@ def _get_stream_for_parsing(self) -> t.BinaryIO: cached_data = getattr(self, "_cached_data", None) if cached_data is not None: return BytesIO(cached_data) - return self.stream + return self.stream # type: ignore def close(self) -> None: """Closes associated resources of this request object. This @@ -324,7 +326,7 @@ def close(self) -> None: def __enter__(self) -> "Request": return self - def __exit__(self, exc_type, exc_value, tb) -> None: + def __exit__(self, exc_type, exc_value, tb) -> None: # type: ignore self.close() @cached_property @@ -370,9 +372,27 @@ def data(self) -> bytes: """ return self.get_data(parse_form_data=True) + @typing.overload + def get_data( # type: ignore + self, + cache: bool = True, + as_text: "te.Literal[False]" = False, + parse_form_data: bool = False, + ) -> bytes: + ... + + @typing.overload + def get_data( + self, + cache: bool = True, + as_text: "te.Literal[True]" = ..., + parse_form_data: bool = False, + ) -> str: + ... + def get_data( self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False - ) -> bytes: + ) -> t.Union[bytes, str]: """This reads the buffered incoming data from the client into one bytes object. By default this is cached but that behavior can be changed by setting `cache` to `False`. @@ -406,7 +426,7 @@ def get_data( self._cached_data = rv if as_text: rv = rv.decode(self.charset, self.encoding_errors) - return rv + return rv # type: ignore @cached_property def form(self) -> "ImmutableMultiDict[str, str]": @@ -425,7 +445,7 @@ def form(self) -> "ImmutableMultiDict[str, str]": and PUT requests. """ self._load_form_data() - return self.form + return self.form # type: ignore @cached_property def values(self) -> "CombinedMultiDict[str, str]": @@ -477,7 +497,7 @@ def files(self) -> "ImmutableMultiDict[str, FileStorage]": more details about the used data structure. """ self._load_form_data() - return self.files + return self.files # type: ignore @property def script_root(self) -> str: @@ -487,11 +507,11 @@ def script_root(self) -> str: return self.root_path @cached_property - def url_root(self): + def url_root(self) -> str: """Alias for :attr:`root_url`. The URL with scheme, host, and root path. For example, ``https://example.com/app/``. """ - return self.root_url + return self.root_url # type: ignore remote_user = environ_property[str]( "REMOTE_USER", @@ -603,7 +623,7 @@ class StreamOnlyMixin: .. versionadded:: 0.9 """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'StreamOnlyMixin' is deprecated and will be removed in" " Werkzeug 2.1. Create the request with 'shallow=True'" @@ -612,7 +632,7 @@ def __init__(self, *args, **kwargs): stacklevel=2, ) kwargs["shallow"] = True - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore class PlainRequest(StreamOnlyMixin, Request): @@ -625,7 +645,7 @@ class PlainRequest(StreamOnlyMixin, Request): .. versionadded:: 0.9 """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'PlainRequest' is deprecated and will be removed in" " Werkzeug 2.1. Create the request with 'shallow=True'" diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index 2e8aa6fc6..60d011c1b 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -293,7 +293,7 @@ def get_data(self, as_text: "te.Literal[False]" = False) -> bytes: def get_data(self, as_text: "te.Literal[True]") -> str: ... - def get_data(self, as_text=False): + def get_data(self, as_text: bool = False) -> t.Union[bytes, str]: """The string representation of the response body. Whenever you call this property the response iterable is encoded and flattened. This can lead to unwanted behavior if you stream big data. @@ -308,8 +308,10 @@ def get_data(self, as_text=False): """ self._ensure_sequence() rv = b"".join(self.iter_encoded()) + if as_text: - rv = rv.decode(self.charset) + return rv.decode(self.charset) + return rv def set_data(self, value: t.Union[bytes, str]) -> None: @@ -439,7 +441,7 @@ def close(self) -> None: def __enter__(self) -> "Response": return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__(self, exc_type, exc_value, tb): # type: ignore self.close() def freeze(self, no_etag: None = None) -> None: @@ -530,7 +532,7 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: current_url = iri_to_uri(current_url) location = url_join(current_url, location) if location != old_location: - headers["Location"] = location # type: ignore + headers["Location"] = location # make sure the content location is a URL if content_location is not None and isinstance(content_location, str): @@ -750,7 +752,7 @@ def make_conditional( request_or_environ: "WSGIEnvironment", accept_ranges: t.Union[bool, str] = False, complete_length: t.Optional[int] = None, - ): + ) -> "Response": """Make the response conditional to the request. This method works best if an etag was defined for the response already. The `add_etag` method can be used to do that. If called without etag just the date @@ -877,7 +879,7 @@ def encoding(self) -> str: class ResponseStreamMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'ResponseStreamMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Response' now includes the functionality" @@ -885,4 +887,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wrappers/user_agent.py b/src/werkzeug/wrappers/user_agent.py index 292d2af6b..184ffd023 100644 --- a/src/werkzeug/wrappers/user_agent.py +++ b/src/werkzeug/wrappers/user_agent.py @@ -1,8 +1,9 @@ +import typing as t import warnings class UserAgentMixin: - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "'UserAgentMixin' is deprecated and will be removed in" " Werkzeug 2.1. 'Request' now includes the functionality" @@ -10,4 +11,4 @@ def __init__(self, *args, **kwargs): DeprecationWarning, stacklevel=2, ) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index c06b2b0c7..587cf3f42 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -157,7 +157,7 @@ def get_input_stream( content length is not set. Disabling this allows infinite streams, which can be a denial-of-service risk. """ - stream = environ["wsgi.input"] + stream = t.cast(t.BinaryIO, environ["wsgi.input"]) content_length = get_content_length(environ) # A wsgi extension that tells us if the input is terminated. In @@ -206,7 +206,7 @@ def get_path_info( .. versionadded:: 0.9 """ path = environ.get("PATH_INFO", "").encode("latin1") - return _to_str(path, charset, errors, allow_none_charset=True) + return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore def get_script_name( @@ -223,7 +223,7 @@ def get_script_name( .. versionadded:: 0.9 """ path = environ.get("SCRIPT_NAME", "").encode("latin1") - return _to_str(path, charset, errors, allow_none_charset=True) + return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore def pop_path_info( @@ -281,7 +281,7 @@ def pop_path_info( environ["SCRIPT_NAME"] = script_name + segment rv = segment.encode("latin1") - return _to_str(rv, charset, errors, allow_none_charset=True) + return _to_str(rv, charset, errors, allow_none_charset=True) # type: ignore def peek_path_info( @@ -309,7 +309,7 @@ def peek_path_info( """ segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) if segments: - return _to_str( + return _to_str( # type: ignore segments[0].encode("latin1"), charset, errors, allow_none_charset=True ) return None @@ -361,8 +361,10 @@ def extract_path_info( .. versionadded:: 0.6 """ - def _normalize_netloc(scheme, netloc): + def _normalize_netloc(scheme: str, netloc: str) -> str: parts = netloc.split("@", 1)[-1].split(":", 1) + port: t.Optional[str] + if len(parts) == 2: netloc, port = parts if (scheme == "http" and port == "80") or ( @@ -372,8 +374,10 @@ def _normalize_netloc(scheme, netloc): else: netloc = parts[0] port = None + if port is not None: netloc += f":{port}" + return netloc # make sure whatever we are working on is a IRI and parse it @@ -480,7 +484,9 @@ def wrap_file( :param file: a :class:`file`-like object with a :meth:`~file.read` method. :param buffer_size: number of bytes for one iteration. """ - return environ.get("wsgi.file_wrapper", FileWrapper)(file, buffer_size) + return environ.get("wsgi.file_wrapper", FileWrapper)( # type: ignore + file, buffer_size + ) class FileWrapper: @@ -516,7 +522,7 @@ def seekable(self) -> bool: return True return False - def seek(self, *args) -> None: + def seek(self, *args: t.Any) -> None: if hasattr(self.file, "seek"): self.file.seek(*args)