Skip to content

Commit

Permalink
Update type annotations for BackgroundTask and utils (#1383)
Browse files Browse the repository at this point in the history
* Update type annotations for BackgroundTask and utils

* Add type annotation to handler

* Update setup.py

* Fix import issue

* Fix missed import

* Fix coverage

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
uriyyo and Kludex committed Jan 8, 2022
1 parent f1c5049 commit 165592f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_long_description():
include_package_data=True,
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"typing_extensions; python_version < '3.10'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
],
extras_require={
Expand Down
12 changes: 10 additions & 2 deletions starlette/background.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import asyncio
import sys
import typing

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette.concurrency import run_in_threadpool

P = ParamSpec("P")


class BackgroundTask:
def __init__(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
self.func = func
self.args = args
Expand All @@ -25,7 +33,7 @@ def __init__(self, tasks: typing.Sequence[BackgroundTask] = None):
self.tasks = list(tasks) if tasks else []

def add_task(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)
Expand Down
20 changes: 14 additions & 6 deletions starlette/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import functools
import sys
import typing
from typing import Any, AsyncGenerator, Iterator

import anyio

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec

try:
import contextvars # Python 3.7+ only or via contextvars backport.
except ImportError: # pragma: no cover
contextvars = None # type: ignore


T = typing.TypeVar("T")
P = ParamSpec("P")


async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
Expand All @@ -25,14 +31,14 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None:


async def run_in_threadpool(
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if contextvars is not None: # pragma: no cover
# Ensure we run in the same context
child = functools.partial(func, *args, **kwargs)
context = contextvars.copy_context()
func = context.run
args = (child,)
func = context.run # type: ignore[assignment]
args = (child,) # type: ignore[assignment]
elif kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
Expand All @@ -43,7 +49,7 @@ class _StopIteration(Exception):
pass


def _next(iterator: Iterator) -> Any:
def _next(iterator: typing.Iterator[T]) -> T:
# We can't raise `StopIteration` from within the threadpool iterator
# and catch it outside that context, so we coerce them into a different
# exception type.
Expand All @@ -53,7 +59,9 @@ def _next(iterator: Iterator) -> Any:
raise _StopIteration


async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator:
async def iterate_in_threadpool(
iterator: typing.Iterator[T],
) -> typing.AsyncIterator[T]:
while True:
try:
yield await anyio.to_thread.run_sync(_next, iterator)
Expand Down
4 changes: 3 additions & 1 deletion starlette/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ async def dispatch(self) -> None:
else request.method.lower()
)

handler = getattr(self, handler_name, self.method_not_allowed)
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
is_async = asyncio.iscoroutinefunction(handler)
if is_async:
response = await handler(request)
Expand Down

0 comments on commit 165592f

Please sign in to comment.