Skip to content

Commit

Permalink
add generics type parameters to ImmutableMultiDict
Browse files Browse the repository at this point in the history
Slice of encode#1403
  • Loading branch information
adriangb committed Jan 29, 2022
1 parent 199fc70 commit 4adca90
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
Address = namedtuple("Address", ["host", "port"])


_T = typing.TypeVar("_T")
_KT = typing.TypeVar("_KT") # key type
_VT_co = typing.TypeVar("_VT_co", covariant=True) # value type for covariant containers


class URL:
def __init__(
self, url: str = "", scope: Scope = None, **components: typing.Any
Expand Down Expand Up @@ -230,13 +235,15 @@ def __str__(self) -> str:
return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
class ImmutableMultiDict(typing.Mapping[_KT, _VT_co]):
_dict: typing.Dict[_KT, _VT_co]

def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
"ImmutableMultiDict[_KT, _VT_co]",
typing.Mapping[_KT, _VT_co],
typing.Sequence[typing.Tuple[_KT, _VT_co]],
],
**kwargs: typing.Any,
) -> None:
Expand All @@ -246,7 +253,7 @@ def __init__(
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
+ ImmutableMultiDict(kwargs).multi_items()
+ ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
)

if not value:
Expand All @@ -266,33 +273,43 @@ def __init__(
self._dict = {k: v for k, v in _items}
self._list = _items

def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
def getlist(self, key: typing.Any) -> typing.List[_VT_co]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
def keys(self) -> typing.KeysView[_KT]:
return self._dict.keys()

def values(self) -> typing.ValuesView:
def values(self) -> typing.ValuesView[_VT_co]:
return self._dict.values()

def items(self) -> typing.ItemsView:
def items(self) -> typing.ItemsView[_KT, _VT_co]:
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT_co]]:
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
@typing.overload
def get(self, key: _KT) -> _VT_co:
... # pragma: no cover

@typing.overload
def get(
self, key: _KT, default: typing.Optional[typing.Union[_VT_co, _T]] = ...
) -> typing.Union[_VT_co, _T]:
... # pragma: no cover

def get(self, key: _KT, default: typing.Any = None) -> typing.Any:
if key in self._dict:
return self._dict[key]
return default

def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: _KT) -> _VT_co:
return self._dict[key]

def __contains__(self, key: typing.Any) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[_KT]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -309,7 +326,7 @@ def __repr__(self) -> str:
return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
self.setlist(key, [value])

Expand Down

0 comments on commit 4adca90

Please sign in to comment.