-
-
Notifications
You must be signed in to change notification settings - Fork 857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: fully type annotate datastructures.py #1403
Changes from all commits
8565404
068f802
8eadb22
506eeae
4ab4843
803e91a
417a3d7
55f7bf6
26e8996
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,35 @@ | ||
import sys | ||
import tempfile | ||
import typing | ||
from collections import namedtuple | ||
from collections.abc import Sequence | ||
from shlex import shlex | ||
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit | ||
|
||
if sys.version_info >= (3, 8): # pragma: no cover | ||
from typing import Protocol | ||
else: # pragma: no cover | ||
from typing_extensions import Protocol | ||
|
||
from starlette.concurrency import run_in_threadpool | ||
from starlette.types import Scope | ||
|
||
Address = namedtuple("Address", ["host", "port"]) | ||
_T = typing.TypeVar("_T") | ||
_KT = typing.TypeVar("_KT") # key type | ||
_VT = typing.TypeVar("_VT") # value type | ||
_KT_co = typing.TypeVar("_KT_co", covariant=True) # key type for covariant containers | ||
_VT_co = typing.TypeVar("_VT_co", covariant=True) # value type for covariant containers | ||
|
||
|
||
class Address(typing.NamedTuple): | ||
host: typing.Optional[str] | ||
port: typing.Optional[int] | ||
|
||
|
||
class URL: | ||
def __init__( | ||
self, url: str = "", scope: Scope = None, **components: typing.Any | ||
self, | ||
url: str = "", | ||
scope: typing.Optional[Scope] = None, | ||
**components: typing.Any, | ||
) -> None: | ||
if scope is not None: | ||
assert not url, 'Cannot set both "url" and "scope".' | ||
|
@@ -202,7 +218,7 @@ def __str__(self) -> str: | |
return self._value | ||
|
||
|
||
class CommaSeparatedStrings(Sequence): | ||
class CommaSeparatedStrings(typing.Sequence[str]): | ||
def __init__(self, value: typing.Union[str, typing.Sequence[str]]): | ||
if isinstance(value, str): | ||
splitter = shlex(value, posix=True) | ||
|
@@ -215,6 +231,14 @@ def __init__(self, value: typing.Union[str, typing.Sequence[str]]): | |
def __len__(self) -> int: | ||
return len(self._items) | ||
|
||
@typing.overload | ||
def __getitem__(self, index: int) -> str: | ||
... # pragma: no cover | ||
|
||
@typing.overload | ||
def __getitem__(self, index: slice) -> typing.Sequence[str]: | ||
... # pragma: no cover | ||
Comment on lines
+234
to
+240
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this PR doesn't get merged, you can also create a separate PR for this. 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep may have to split this up |
||
|
||
def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any: | ||
return self._items[index] | ||
|
||
|
@@ -230,13 +254,14 @@ def __str__(self) -> str: | |
return ", ".join(repr(item) for item in self) | ||
|
||
|
||
class ImmutableMultiDict(typing.Mapping): | ||
class ImmutableMultiDict(typing.Mapping[_KT, _VT_co]): | ||
_dict: typing.Dict[_KT, _VT_co] | ||
def __init__( | ||
self, | ||
*args: typing.Union[ | ||
"ImmutableMultiDict", | ||
typing.Mapping, | ||
typing.List[typing.Tuple[typing.Any, typing.Any]], | ||
"ImmutableMultiDict[_KT, _VT_co]", | ||
typing.Mapping[_KT, _VT_co], | ||
typing.Sequence[typing.Tuple[_KT, _VT_co]], | ||
], | ||
**kwargs: typing.Any, | ||
) -> None: | ||
|
@@ -246,7 +271,7 @@ def __init__( | |
if kwargs: | ||
value = ( | ||
ImmutableMultiDict(value).multi_items() | ||
+ ImmutableMultiDict(kwargs).multi_items() | ||
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is not a valid operation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because the type of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this could be fixed with a bit of juggling of how this is handled, but I wanted to do this with the least amount of runtime changes necessary, so adding a |
||
) | ||
|
||
if not value: | ||
|
@@ -266,33 +291,43 @@ def __init__( | |
self._dict = {k: v for k, v in _items} | ||
self._list = _items | ||
|
||
def getlist(self, key: typing.Any) -> typing.List[typing.Any]: | ||
def getlist(self, key: typing.Any) -> typing.List[_VT_co]: | ||
return [item_value for item_key, item_value in self._list if item_key == key] | ||
|
||
def keys(self) -> typing.KeysView: | ||
def keys(self) -> typing.KeysView[_KT]: | ||
return self._dict.keys() | ||
|
||
def values(self) -> typing.ValuesView: | ||
def values(self) -> typing.ValuesView[_VT_co]: | ||
return self._dict.values() | ||
|
||
def items(self) -> typing.ItemsView: | ||
def items(self) -> typing.ItemsView[_KT, _VT_co]: | ||
return self._dict.items() | ||
|
||
def multi_items(self) -> typing.List[typing.Tuple[str, str]]: | ||
def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT_co]]: | ||
return list(self._list) | ||
|
||
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: | ||
@typing.overload | ||
def get(self, key: _KT) -> _VT_co: | ||
... # pragma: no cover | ||
|
||
@typing.overload | ||
def get( | ||
self, key: _KT, default: typing.Optional[typing.Union[_VT_co, _T]] = ... | ||
) -> typing.Union[_VT_co, _T]: | ||
... # pragma: no cover | ||
|
||
def get(self, key: _KT, default: typing.Any = None) -> typing.Any: | ||
if key in self._dict: | ||
return self._dict[key] | ||
return default | ||
|
||
def __getitem__(self, key: typing.Any) -> str: | ||
def __getitem__(self, key: _KT) -> _VT_co: | ||
return self._dict[key] | ||
|
||
def __contains__(self, key: typing.Any) -> bool: | ||
return key in self._dict | ||
|
||
def __iter__(self) -> typing.Iterator[typing.Any]: | ||
def __iter__(self) -> typing.Iterator[_KT]: | ||
return iter(self.keys()) | ||
|
||
def __len__(self) -> int: | ||
|
@@ -309,24 +344,34 @@ def __repr__(self) -> str: | |
return f"{class_name}({items!r})" | ||
|
||
|
||
class MultiDict(ImmutableMultiDict): | ||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None: | ||
class MultiDict(ImmutableMultiDict[_KT, _VT]): | ||
def __setitem__(self, key: _KT, value: _VT) -> None: | ||
self.setlist(key, [value]) | ||
|
||
def __delitem__(self, key: typing.Any) -> None: | ||
def __delitem__(self, key: _KT) -> None: | ||
self._list = [(k, v) for k, v in self._list if k != key] | ||
del self._dict[key] | ||
|
||
def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: | ||
@typing.overload | ||
def pop(self, key: _KT) -> _VT: | ||
... # pragma: no cover | ||
|
||
@typing.overload | ||
def pop( | ||
self, key: _KT, default: typing.Optional[typing.Union[_VT, _T]] = ... | ||
) -> typing.Union[_VT, _T, None]: | ||
... # pragma: no cover | ||
|
||
def pop(self, key: _KT, default: typing.Any = None) -> typing.Any: | ||
self._list = [(k, v) for k, v in self._list if k != key] | ||
return self._dict.pop(key, default) | ||
|
||
def popitem(self) -> typing.Tuple: | ||
def popitem(self) -> typing.Tuple[_KT, _VT]: | ||
key, value = self._dict.popitem() | ||
self._list = [(k, v) for k, v in self._list if k != key] | ||
return key, value | ||
|
||
def poplist(self, key: typing.Any) -> typing.List: | ||
def poplist(self, key: typing.Any) -> typing.List[_VT]: | ||
values = [v for k, v in self._list if k == key] | ||
self.pop(key) | ||
return values | ||
|
@@ -335,14 +380,14 @@ def clear(self) -> None: | |
self._dict.clear() | ||
self._list.clear() | ||
|
||
def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: | ||
def setdefault(self, key: _KT, default: typing.Union[_VT, None] = None) -> _VT: | ||
if key not in self: | ||
self._dict[key] = default | ||
self._list.append((key, default)) | ||
|
||
return self[key] | ||
|
||
def setlist(self, key: typing.Any, values: typing.List) -> None: | ||
def setlist(self, key: _KT, values: typing.Sequence[_VT]) -> None: | ||
if not values: | ||
self.pop(key, None) | ||
else: | ||
|
@@ -357,8 +402,8 @@ def append(self, key: typing.Any, value: typing.Any) -> None: | |
def update( | ||
self, | ||
*args: typing.Union[ | ||
"MultiDict", | ||
typing.Mapping, | ||
"MultiDict[_KT, _VT]", | ||
typing.Mapping[_KT, _VT], | ||
typing.List[typing.Tuple[typing.Any, typing.Any]], | ||
], | ||
**kwargs: typing.Any, | ||
|
@@ -369,16 +414,16 @@ def update( | |
self._dict.update(value) | ||
|
||
|
||
class QueryParams(ImmutableMultiDict): | ||
class QueryParams(ImmutableMultiDict[str, str]): | ||
""" | ||
An immutable multidict. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*args: typing.Union[ | ||
"ImmutableMultiDict", | ||
typing.Mapping, | ||
"ImmutableMultiDict[str, str]", | ||
typing.Mapping[str, str], | ||
typing.List[typing.Tuple[typing.Any, typing.Any]], | ||
str, | ||
bytes, | ||
|
@@ -463,7 +508,7 @@ async def close(self) -> None: | |
await run_in_threadpool(self.file.close) | ||
|
||
|
||
class FormData(ImmutableMultiDict): | ||
class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]): | ||
""" | ||
An immutable multidict, containing both file uploads and text input. | ||
""" | ||
|
@@ -480,7 +525,7 @@ def __init__( | |
super().__init__(*args, **kwargs) | ||
|
||
async def close(self) -> None: | ||
for key, value in self.multi_items(): | ||
for _, value in self.multi_items(): | ||
if isinstance(value, UploadFile): | ||
await value.close() | ||
|
||
|
@@ -492,9 +537,9 @@ class Headers(typing.Mapping[str, str]): | |
|
||
def __init__( | ||
self, | ||
headers: typing.Mapping[str, str] = None, | ||
raw: typing.List[typing.Tuple[bytes, bytes]] = None, | ||
scope: Scope = None, | ||
headers: typing.Optional[typing.Mapping[str, str]] = None, | ||
raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None, | ||
scope: typing.Optional[Scope] = None, | ||
) -> None: | ||
self._list: typing.List[typing.Tuple[bytes, bytes]] = [] | ||
if headers is not None: | ||
|
@@ -515,17 +560,27 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: | |
return list(self._list) | ||
|
||
def keys(self) -> typing.List[str]: # type: ignore | ||
return [key.decode("latin-1") for key, value in self._list] | ||
return [key.decode("latin-1") for key, _ in self._list] | ||
|
||
def values(self) -> typing.List[str]: # type: ignore | ||
return [value.decode("latin-1") for key, value in self._list] | ||
return [value.decode("latin-1") for _, value in self._list] | ||
|
||
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore | ||
return [ | ||
(key.decode("latin-1"), value.decode("latin-1")) | ||
for key, value in self._list | ||
] | ||
|
||
@typing.overload | ||
def get(self, key: str) -> str: | ||
... # pragma: no cover | ||
|
||
@typing.overload | ||
def get( | ||
self, key: str, default: typing.Optional[typing.Union[str, _T]] = ... | ||
) -> typing.Union[str, _T]: | ||
... # pragma: no cover | ||
|
||
def get(self, key: str, default: typing.Any = None) -> typing.Any: | ||
try: | ||
return self[key] | ||
|
@@ -551,13 +606,14 @@ def __getitem__(self, key: str) -> str: | |
raise KeyError(key) | ||
|
||
def __contains__(self, key: typing.Any) -> bool: | ||
assert isinstance(key, str) | ||
get_header_key = key.lower().encode("latin-1") | ||
for header_key, header_value in self._list: | ||
for header_key, _ in self._list: | ||
if header_key == get_header_key: | ||
return True | ||
return False | ||
|
||
def __iter__(self) -> typing.Iterator[typing.Any]: | ||
def __iter__(self) -> typing.Iterator[str]: | ||
return iter(self.keys()) | ||
|
||
def __len__(self) -> int: | ||
|
@@ -576,6 +632,11 @@ def __repr__(self) -> str: | |
return f"{class_name}(raw={self.raw!r})" | ||
|
||
|
||
class _SupportsItems(Protocol[_KT_co, _VT_co]): | ||
def items(self) -> typing.Iterable[typing.Tuple[_KT_co, _VT_co]]: | ||
... # pragma: no cover | ||
|
||
|
||
class MutableHeaders(Headers): | ||
def __setitem__(self, key: str, value: str) -> None: | ||
""" | ||
|
@@ -585,8 +646,8 @@ def __setitem__(self, key: str, value: str) -> None: | |
set_key = key.lower().encode("latin-1") | ||
set_value = value.encode("latin-1") | ||
|
||
found_indexes = [] | ||
for idx, (item_key, item_value) in enumerate(self._list): | ||
found_indexes: typing.List[int] = [] | ||
for idx, (item_key, _) in enumerate(self._list): | ||
if item_key == set_key: | ||
found_indexes.append(idx) | ||
|
||
|
@@ -605,8 +666,8 @@ def __delitem__(self, key: str) -> None: | |
""" | ||
del_key = key.lower().encode("latin-1") | ||
|
||
pop_indexes = [] | ||
for idx, (item_key, item_value) in enumerate(self._list): | ||
pop_indexes: typing.List[int] = [] | ||
for idx, (item_key, _) in enumerate(self._list): | ||
if item_key == del_key: | ||
pop_indexes.append(idx) | ||
|
||
|
@@ -625,13 +686,13 @@ def setdefault(self, key: str, value: str) -> str: | |
set_key = key.lower().encode("latin-1") | ||
set_value = value.encode("latin-1") | ||
|
||
for idx, (item_key, item_value) in enumerate(self._list): | ||
for _, (item_key, item_value) in enumerate(self._list): | ||
if item_key == set_key: | ||
return item_value.decode("latin-1") | ||
self._list.append((set_key, set_value)) | ||
return value | ||
|
||
def update(self, other: dict) -> None: | ||
def update(self, other: _SupportsItems[str, str]) -> None: | ||
for key, val in other.items(): | ||
self[key] = val | ||
|
||
|
@@ -657,7 +718,9 @@ class State: | |
Used for `request.state` and `app.state`. | ||
""" | ||
|
||
def __init__(self, state: typing.Dict = None): | ||
def __init__( | ||
self, state: typing.Optional[typing.Mapping[typing.Any, typing.Any]] = None | ||
): | ||
if state is None: | ||
state = {} | ||
super().__setattr__("_state", state) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain this change request to me please, or reference some documentation on the subject? I understand that there is both a practical and semantic difference, but as far as I understand using
Scope | None
is now preferred if the default value isNone
. For example, see https://stackoverflow.com/questions/62732402/can-i-omit-optional-if-i-set-default-to-none