Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fully type annotate datastructures.py #1403

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
157 changes: 110 additions & 47 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
import sys
import tempfile
import typing
from collections import namedtuple
from collections.abc import Sequence
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

if sys.version_info >= (3, 8): # pragma: no cover
from typing import Protocol
else: # pragma: no cover
from typing_extensions import Protocol

from starlette.concurrency import run_in_threadpool
from starlette.types import Scope

Address = namedtuple("Address", ["host", "port"])
_T = typing.TypeVar("_T")
_KT = typing.TypeVar("_KT") # key type
_VT = typing.TypeVar("_VT") # value type
_KT_co = typing.TypeVar("_KT_co", covariant=True) # key type for covariant containers
_VT_co = typing.TypeVar("_VT_co", covariant=True) # value type for covariant containers


class Address(typing.NamedTuple):
host: typing.Optional[str]
port: typing.Optional[int]


class URL:
def __init__(
self, url: str = "", scope: Scope = None, **components: typing.Any
self,
url: str = "",
scope: typing.Optional[Scope] = None,
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scope: typing.Optional[Scope] = None,
scope: Scope = None,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this change request to me please, or reference some documentation on the subject? I understand that there is both a practical and semantic difference, but as far as I understand using Scope | None is now preferred if the default value is None. For example, see https://stackoverflow.com/questions/62732402/can-i-omit-optional-if-i-set-default-to-none

**components: typing.Any,
) -> None:
if scope is not None:
assert not url, 'Cannot set both "url" and "scope".'
Expand Down Expand Up @@ -202,7 +218,7 @@ def __str__(self) -> str:
return self._value


class CommaSeparatedStrings(Sequence):
class CommaSeparatedStrings(typing.Sequence[str]):
def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
Expand All @@ -215,6 +231,14 @@ def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
def __len__(self) -> int:
return len(self._items)

@typing.overload
def __getitem__(self, index: int) -> str:
... # pragma: no cover

@typing.overload
def __getitem__(self, index: slice) -> typing.Sequence[str]:
... # pragma: no cover
Comment on lines +234 to +240
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this PR doesn't get merged, you can also create a separate PR for this. 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep may have to split this up


def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
return self._items[index]

Expand All @@ -230,13 +254,14 @@ def __str__(self) -> str:
return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
class ImmutableMultiDict(typing.Mapping[_KT, _VT_co]):
_dict: typing.Dict[_KT, _VT_co]
def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
"ImmutableMultiDict[_KT, _VT_co]",
typing.Mapping[_KT, _VT_co],
typing.Sequence[typing.Tuple[_KT, _VT_co]],
],
**kwargs: typing.Any,
) -> None:
Expand All @@ -246,7 +271,7 @@ def __init__(
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
+ ImmutableMultiDict(kwargs).multi_items()
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is not a valid operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the type of value cannot be inferred from the expression value = args[0] if args or []

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could be fixed with a bit of juggling of how this is handled, but I wanted to do this with the least amount of runtime changes necessary, so adding a type: ignore for now and maybe doing a rework in a future PR seems like the best bet

)

if not value:
Expand All @@ -266,33 +291,43 @@ def __init__(
self._dict = {k: v for k, v in _items}
self._list = _items

def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
def getlist(self, key: typing.Any) -> typing.List[_VT_co]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
def keys(self) -> typing.KeysView[_KT]:
return self._dict.keys()

def values(self) -> typing.ValuesView:
def values(self) -> typing.ValuesView[_VT_co]:
return self._dict.values()

def items(self) -> typing.ItemsView:
def items(self) -> typing.ItemsView[_KT, _VT_co]:
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT_co]]:
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
@typing.overload
def get(self, key: _KT) -> _VT_co:
... # pragma: no cover

@typing.overload
def get(
self, key: _KT, default: typing.Optional[typing.Union[_VT_co, _T]] = ...
) -> typing.Union[_VT_co, _T]:
... # pragma: no cover

def get(self, key: _KT, default: typing.Any = None) -> typing.Any:
if key in self._dict:
return self._dict[key]
return default

def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: _KT) -> _VT_co:
return self._dict[key]

def __contains__(self, key: typing.Any) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[_KT]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -309,24 +344,34 @@ def __repr__(self) -> str:
return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
class MultiDict(ImmutableMultiDict[_KT, _VT]):
def __setitem__(self, key: _KT, value: _VT) -> None:
self.setlist(key, [value])

def __delitem__(self, key: typing.Any) -> None:
def __delitem__(self, key: _KT) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]

def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
@typing.overload
def pop(self, key: _KT) -> _VT:
... # pragma: no cover

@typing.overload
def pop(
self, key: _KT, default: typing.Optional[typing.Union[_VT, _T]] = ...
) -> typing.Union[_VT, _T, None]:
... # pragma: no cover

def pop(self, key: _KT, default: typing.Any = None) -> typing.Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)

def popitem(self) -> typing.Tuple:
def popitem(self) -> typing.Tuple[_KT, _VT]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value

def poplist(self, key: typing.Any) -> typing.List:
def poplist(self, key: typing.Any) -> typing.List[_VT]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
Expand All @@ -335,14 +380,14 @@ def clear(self) -> None:
self._dict.clear()
self._list.clear()

def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
def setdefault(self, key: _KT, default: typing.Union[_VT, None] = None) -> _VT:
if key not in self:
self._dict[key] = default
self._list.append((key, default))

return self[key]

def setlist(self, key: typing.Any, values: typing.List) -> None:
def setlist(self, key: _KT, values: typing.Sequence[_VT]) -> None:
if not values:
self.pop(key, None)
else:
Expand All @@ -357,8 +402,8 @@ def append(self, key: typing.Any, value: typing.Any) -> None:
def update(
self,
*args: typing.Union[
"MultiDict",
typing.Mapping,
"MultiDict[_KT, _VT]",
typing.Mapping[_KT, _VT],
typing.List[typing.Tuple[typing.Any, typing.Any]],
],
**kwargs: typing.Any,
Expand All @@ -369,16 +414,16 @@ def update(
self._dict.update(value)


class QueryParams(ImmutableMultiDict):
class QueryParams(ImmutableMultiDict[str, str]):
"""
An immutable multidict.
"""

def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
"ImmutableMultiDict[str, str]",
typing.Mapping[str, str],
typing.List[typing.Tuple[typing.Any, typing.Any]],
str,
bytes,
Expand Down Expand Up @@ -463,7 +508,7 @@ async def close(self) -> None:
await run_in_threadpool(self.file.close)


class FormData(ImmutableMultiDict):
class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]):
"""
An immutable multidict, containing both file uploads and text input.
"""
Expand All @@ -480,7 +525,7 @@ def __init__(
super().__init__(*args, **kwargs)

async def close(self) -> None:
for key, value in self.multi_items():
for _, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()

Expand All @@ -492,9 +537,9 @@ class Headers(typing.Mapping[str, str]):

def __init__(
self,
headers: typing.Mapping[str, str] = None,
raw: typing.List[typing.Tuple[bytes, bytes]] = None,
scope: Scope = None,
headers: typing.Optional[typing.Mapping[str, str]] = None,
raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None,
scope: typing.Optional[Scope] = None,
) -> None:
self._list: typing.List[typing.Tuple[bytes, bytes]] = []
if headers is not None:
Expand All @@ -515,17 +560,27 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return list(self._list)

def keys(self) -> typing.List[str]: # type: ignore
return [key.decode("latin-1") for key, value in self._list]
return [key.decode("latin-1") for key, _ in self._list]

def values(self) -> typing.List[str]: # type: ignore
return [value.decode("latin-1") for key, value in self._list]
return [value.decode("latin-1") for _, value in self._list]

def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]

@typing.overload
def get(self, key: str) -> str:
... # pragma: no cover

@typing.overload
def get(
self, key: str, default: typing.Optional[typing.Union[str, _T]] = ...
) -> typing.Union[str, _T]:
... # pragma: no cover

def get(self, key: str, default: typing.Any = None) -> typing.Any:
try:
return self[key]
Expand All @@ -551,13 +606,14 @@ def __getitem__(self, key: str) -> str:
raise KeyError(key)

def __contains__(self, key: typing.Any) -> bool:
assert isinstance(key, str)
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
for header_key, _ in self._list:
if header_key == get_header_key:
return True
return False

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[str]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -576,6 +632,11 @@ def __repr__(self) -> str:
return f"{class_name}(raw={self.raw!r})"


class _SupportsItems(Protocol[_KT_co, _VT_co]):
def items(self) -> typing.Iterable[typing.Tuple[_KT_co, _VT_co]]:
... # pragma: no cover


class MutableHeaders(Headers):
def __setitem__(self, key: str, value: str) -> None:
"""
Expand All @@ -585,8 +646,8 @@ def __setitem__(self, key: str, value: str) -> None:
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")

found_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
found_indexes: typing.List[int] = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)

Expand All @@ -605,8 +666,8 @@ def __delitem__(self, key: str) -> None:
"""
del_key = key.lower().encode("latin-1")

pop_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
pop_indexes: typing.List[int] = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)

Expand All @@ -625,13 +686,13 @@ def setdefault(self, key: str, value: str) -> str:
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")

for idx, (item_key, item_value) in enumerate(self._list):
for _, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
return item_value.decode("latin-1")
self._list.append((set_key, set_value))
return value

def update(self, other: dict) -> None:
def update(self, other: _SupportsItems[str, str]) -> None:
for key, val in other.items():
self[key] = val

Expand All @@ -657,7 +718,9 @@ class State:
Used for `request.state` and `app.state`.
"""

def __init__(self, state: typing.Dict = None):
def __init__(
self, state: typing.Optional[typing.Mapping[typing.Any, typing.Any]] = None
):
if state is None:
state = {}
super().__setattr__("_state", state)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __bool__(self):
async def app(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any]]] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
Expand Down Expand Up @@ -60,7 +60,7 @@ async def multi_items_app(scope, receive, send):
async def app_with_headers(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any]]] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
Expand Down