Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fully type annotate datastructures.py #1403

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
158 changes: 110 additions & 48 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
import sys
import tempfile
import typing
from collections import namedtuple
from collections.abc import Sequence
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

if sys.version_info >= (3, 8): # pragma: no cover
from typing import Protocol
else: # pragma: no cover
from typing_extensions import Protocol

from starlette.concurrency import run_in_threadpool
from starlette.types import Scope

Address = namedtuple("Address", ["host", "port"])
_T = typing.TypeVar("_T")
_KT = typing.TypeVar("_KT") # key type
_VT = typing.TypeVar("_VT") # value type
_KT_co = typing.TypeVar("_KT_co", covariant=True) # key type for covariant containers
_VT_co = typing.TypeVar("_VT_co", covariant=True) # value type for covariant containers


class Address(typing.NamedTuple):
host: typing.Optional[str]
port: typing.Optional[int]


class URL:
def __init__(
self, url: str = "", scope: Scope = None, **components: typing.Any
self,
url: str = "",
scope: typing.Optional[Scope] = None,
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scope: typing.Optional[Scope] = None,
scope: Scope = None,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this change request to me please, or reference some documentation on the subject? I understand that there is both a practical and semantic difference, but as far as I understand using Scope | None is now preferred if the default value is None. For example, see https://stackoverflow.com/questions/62732402/can-i-omit-optional-if-i-set-default-to-none

**components: typing.Any,
) -> None:
if scope is not None:
assert not url, 'Cannot set both "url" and "scope".'
Expand Down Expand Up @@ -202,7 +218,7 @@ def __str__(self) -> str:
return self._value


class CommaSeparatedStrings(Sequence):
class CommaSeparatedStrings(typing.Sequence[str]):
def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
Expand All @@ -215,6 +231,14 @@ def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
def __len__(self) -> int:
return len(self._items)

@typing.overload
def __getitem__(self, index: int) -> str:
... # pragma: no cover

@typing.overload
def __getitem__(self, index: slice) -> typing.Sequence[str]:
... # pragma: no cover
Comment on lines +234 to +240
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this PR doesn't get merged, you can also create a separate PR for this. 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep may have to split this up


def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
return self._items[index]

Expand All @@ -230,13 +254,13 @@ def __str__(self) -> str:
return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
class ImmutableMultiDict(typing.Mapping[_KT, _VT_co]):
def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
"ImmutableMultiDict[_KT, _VT_co]",
typing.Mapping[_KT, _VT_co],
typing.Sequence[typing.Tuple[_KT, _VT_co]],
],
**kwargs: typing.Any,
) -> None:
Expand All @@ -246,7 +270,7 @@ def __init__(
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
+ ImmutableMultiDict(kwargs).multi_items()
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is not a valid operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the type of value cannot be inferred from the expression value = args[0] if args or []

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could be fixed with a bit of juggling of how this is handled, but I wanted to do this with the least amount of runtime changes necessary, so adding a type: ignore for now and maybe doing a rework in a future PR seems like the best bet

)

if not value:
Expand All @@ -266,33 +290,43 @@ def __init__(
self._dict = {k: v for k, v in _items}
self._list = _items

def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
def getlist(self, key: typing.Any) -> typing.List[_VT_co]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
def keys(self) -> typing.KeysView[_KT]:
return self._dict.keys()

def values(self) -> typing.ValuesView:
def values(self) -> typing.ValuesView[_VT_co]:
return self._dict.values()

def items(self) -> typing.ItemsView:
def items(self) -> typing.ItemsView[_KT, _VT_co]:
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT_co]]:
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
@typing.overload
def get(self, key: _KT) -> _VT_co:
... # pragma: no cover

@typing.overload
def get(
self, key: _KT, default: typing.Optional[typing.Union[_VT_co, _T]] = ...
) -> typing.Union[_VT_co, _T]:
... # pragma: no cover

def get(self, key: _KT, default: typing.Any = None) -> typing.Any:
if key in self._dict:
return self._dict[key]
return typing.cast(_VT_co, self._dict[key])
return default
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I fully understand the reason of what is written here. 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is saying:

  • If you do not provide a default, the return value is going to be of type _VT_co (so a value of the dictionary) (or an error)
  • If you do provide a default value, the return value is going to be either _VT_co (if the key is found) or the type of your default value (if the key is not found)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the typing.cast by giving ImmutableMultiDict._dict a type annotation


def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: _KT) -> _VT_co:
return self._dict[key]

def __contains__(self, key: typing.Any) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[_KT]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -309,24 +343,34 @@ def __repr__(self) -> str:
return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
class MultiDict(ImmutableMultiDict[_KT, _VT]):
def __setitem__(self, key: _KT, value: _VT) -> None:
self.setlist(key, [value])

def __delitem__(self, key: typing.Any) -> None:
def __delitem__(self, key: _KT) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]

def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
@typing.overload
def pop(self, key: _KT) -> _VT:
... # pragma: no cover

@typing.overload
def pop(
self, key: _KT, default: typing.Optional[typing.Union[_VT, _T]] = ...
) -> typing.Union[_VT, _T, None]:
... # pragma: no cover

def pop(self, key: _KT, default: typing.Any = None) -> typing.Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)

def popitem(self) -> typing.Tuple:
def popitem(self) -> typing.Tuple[_KT, _VT]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value

def poplist(self, key: typing.Any) -> typing.List:
def poplist(self, key: typing.Any) -> typing.List[_VT]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
Expand All @@ -335,14 +379,14 @@ def clear(self) -> None:
self._dict.clear()
self._list.clear()

def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
def setdefault(self, key: _KT, default: typing.Union[_VT, None] = None) -> _VT:
if key not in self:
self._dict[key] = default
self._list.append((key, default))

return self[key]

def setlist(self, key: typing.Any, values: typing.List) -> None:
def setlist(self, key: _KT, values: typing.Sequence[_VT]) -> None:
if not values:
self.pop(key, None)
else:
Expand All @@ -357,8 +401,8 @@ def append(self, key: typing.Any, value: typing.Any) -> None:
def update(
self,
*args: typing.Union[
"MultiDict",
typing.Mapping,
"MultiDict[_KT, _VT]",
typing.Mapping[_KT, _VT],
typing.List[typing.Tuple[typing.Any, typing.Any]],
],
**kwargs: typing.Any,
Expand All @@ -369,16 +413,16 @@ def update(
self._dict.update(value)


class QueryParams(ImmutableMultiDict):
class QueryParams(ImmutableMultiDict[str, str]):
"""
An immutable multidict.
"""

def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
"ImmutableMultiDict[str, str]",
typing.Mapping[str, str],
typing.List[typing.Tuple[typing.Any, typing.Any]],
str,
bytes,
Expand Down Expand Up @@ -463,7 +507,7 @@ async def close(self) -> None:
await run_in_threadpool(self.file.close)


class FormData(ImmutableMultiDict):
class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]):
"""
An immutable multidict, containing both file uploads and text input.
"""
Expand All @@ -480,7 +524,7 @@ def __init__(
super().__init__(*args, **kwargs)

async def close(self) -> None:
for key, value in self.multi_items():
for _, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()

Expand All @@ -492,9 +536,9 @@ class Headers(typing.Mapping[str, str]):

def __init__(
self,
headers: typing.Mapping[str, str] = None,
raw: typing.List[typing.Tuple[bytes, bytes]] = None,
scope: Scope = None,
headers: typing.Optional[typing.Mapping[str, str]] = None,
raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None,
scope: typing.Optional[Scope] = None,
) -> None:
self._list: typing.List[typing.Tuple[bytes, bytes]] = []
if headers is not None:
Expand All @@ -515,17 +559,27 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return list(self._list)

def keys(self) -> typing.List[str]: # type: ignore
return [key.decode("latin-1") for key, value in self._list]
return [key.decode("latin-1") for key, _ in self._list]

def values(self) -> typing.List[str]: # type: ignore
return [value.decode("latin-1") for key, value in self._list]
return [value.decode("latin-1") for _, value in self._list]

def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]

@typing.overload
def get(self, key: str) -> str:
... # pragma: no cover

@typing.overload
def get(
self, key: str, default: typing.Optional[typing.Union[str, _T]] = ...
) -> typing.Union[str, _T]:
... # pragma: no cover

def get(self, key: str, default: typing.Any = None) -> typing.Any:
try:
return self[key]
Expand All @@ -551,13 +605,14 @@ def __getitem__(self, key: str) -> str:
raise KeyError(key)

def __contains__(self, key: typing.Any) -> bool:
assert isinstance(key, str)
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
for header_key, _ in self._list:
if header_key == get_header_key:
return True
return False

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[str]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -576,6 +631,11 @@ def __repr__(self) -> str:
return f"{class_name}(raw={self.raw!r})"


class _SupportsItems(Protocol[_KT_co, _VT_co]):
def items(self) -> typing.Iterable[typing.Tuple[_KT_co, _VT_co]]:
... # pragma: no cover


class MutableHeaders(Headers):
def __setitem__(self, key: str, value: str) -> None:
"""
Expand All @@ -585,8 +645,8 @@ def __setitem__(self, key: str, value: str) -> None:
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")

found_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
found_indexes: typing.List[int] = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)

Expand All @@ -605,8 +665,8 @@ def __delitem__(self, key: str) -> None:
"""
del_key = key.lower().encode("latin-1")

pop_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
pop_indexes: typing.List[int] = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)

Expand All @@ -625,13 +685,13 @@ def setdefault(self, key: str, value: str) -> str:
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")

for idx, (item_key, item_value) in enumerate(self._list):
for _, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
return item_value.decode("latin-1")
self._list.append((set_key, set_value))
return value

def update(self, other: dict) -> None:
def update(self, other: _SupportsItems[str, str]) -> None:
for key, val in other.items():
self[key] = val

Expand All @@ -657,7 +717,9 @@ class State:
Used for `request.state` and `app.state`.
"""

def __init__(self, state: typing.Dict = None):
def __init__(
self, state: typing.Optional[typing.Mapping[typing.Any, typing.Any]] = None
):
if state is None:
state = {}
super().__setattr__("_state", state)
Expand Down
2 changes: 1 addition & 1 deletion starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def cookies(self) -> typing.Dict[str, str]:
@property
def client(self) -> Address:
host, port = self.scope.get("client") or (None, None)
return Address(host=host, port=port)
return Address(host=host, port=port) # type: ignore[arg-type]
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm... This is actually not ASGI compliant: https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope

client (Iterable[Unicode string, int]) – A two-item iterable of [host, port], where host is the remote host’s IPv4 or IPv6 address, and port is the remote port as an integer. Optional; if missing defaults to None.

Copy link
Member Author

@adriangb adriangb Jan 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this not ASGI compliant? The NamedTuple is a tw-item utterable of [host, port]. It's the same thing we had before. In fact, I don't think we even need the # type: ignore. I just removed it in 417a3d7 and now there are 0 changes in this file.

Copy link
Sponsor Member

@Kludex Kludex Jan 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because Address is not supposed to accept (None, None), client should actually be Optional[Address].

My bad anyway, I should have mentioned that it's not because of your PR... I just noticed that it's currently wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's fix it in another PR then 😄

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #1462


@property
def session(self) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __bool__(self):
async def app(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any]]] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
Expand Down Expand Up @@ -60,7 +60,7 @@ async def multi_items_app(scope, receive, send):
async def app_with_headers(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, typing.Union[str, typing.Dict[str, typing.Any]]] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
Expand Down