diff --git a/docs/middleware.md b/docs/middleware.md index cecdafaaf..8f063dec3 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -226,12 +226,286 @@ Middleware classes should not modify their state outside of the `__init__` metho Instead you should keep any state local to the `dispatch` method, or pass it around explicitly, rather than mutating the middleware instance. -!!! bug - Currently, the `BaseHTTPMiddleware` has some known issues: +!!! warning + Currently, the `BaseHTTPMiddleware` has some known limitations: - It's not possible to use `BackgroundTasks` with `BaseHTTPMiddleware`. Check [#1438](https://github.com/encode/starlette/issues/1438) for more details. - Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). +## Pure ASGI Middleware + +Thanks to how ASGI was designed, it is possible to implement middleware as a chain of ASGI applications, where each application calls into the next one. +Each element of the chain is an [`ASGI`](https://asgi.readthedocs.io/en/latest/) application by itself, which, by definition, is also a middleware. + +This is also an alternative approach in case the limitations of `BaseHTTPMiddleware` are a problem. + +### Guiding principles + +The most common way to create an ASGI middleware is with a class. + +```python +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + await self.app(scope, receive, send) +``` + +The middleware above is the most basic ASGI middleware.
It receives an ASGI application as an argument to its constructor, and implements the `async __call__` method. This method should accept `scope`, which contains information about the current connection, and `receive` and `send`, which let you exchange ASGI event messages with the ASGI server (learn more in the [ASGI specification](https://asgi.readthedocs.io/en/latest/specs/index.html)). + +As an alternative to the class approach, you can also use a function: + +```python +import functools + +def asgi_middleware(): + def asgi_decorator(app): + @functools.wraps(app) + async def wrapped_app(scope, receive, send): + await app(scope, receive, send) + return wrapped_app + return asgi_decorator +``` + +!!! info + The function pattern is not widely used, but you can check a more advanced implementation of it in + [asgi-cors](https://github.com/simonw/asgi-cors/blob/10ef64bfcc6cd8d16f3014077f20a0fb8544ec39/asgi_cors.py). + +#### `Scope` types + +As we mentioned, the scope holds the information about the connection. There are three types of `scope`s: + +- [`lifespan`](https://asgi.readthedocs.io/en/latest/specs/lifespan.html#scope) is a special type of scope that is used for the lifespan of the ASGI application. +- [`http`](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope) is a type of scope that is used for HTTP requests. +- [`websocket`](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope) is a type of scope that is used for WebSocket connections. + +If you want to create a middleware that only runs on HTTP requests, you'd write something like: + +```python +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + return await self.app(scope, receive, send) + + # Do something here!
+ await self.app(scope, receive, send) +``` +In the example above, if the `scope` type is **not** `http`, meaning it is either `lifespan` or `websocket`, we'll call `self.app` directly. + +The same applies to other scopes. + +!!! tip + Middleware classes should be stateless -- see [Per-request state](#per-request-state) if you do need to store per-request state. + +#### Wrapping `send` and `receive` + +A common pattern that you'll probably need is wrapping the `send` or `receive` callables. + +For example, here's how we could write a middleware that logs the response status code, which we'd obtain +by wrapping the `send` with the `send_wrapper` callable: + +```python +class LogStatusCodeMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + return await self.app(scope, receive, send) + + status_code = 500 + + async def send_wrapper(message): + nonlocal status_code + if message["type"] == "http.response.start": + status_code = message["status"] + await send(message) + + await self.app(scope, receive, send_wrapper) + print("This is a primitive access log") + print(f"status = {status_code}") +``` + +!!! info + You can check a more advanced implementation of the same idea in [asgi-logger](https://github.com/Kludex/asgi-logger/blob/main/asgi_logger/middleware.py). + +#### Type annotations + +There are two ways of annotating a middleware: using Starlette itself or [`asgiref`](https://github.com/django/asgiref). + +Using Starlette, you can write: + +```python +from starlette.types import ASGIApp, Message, Scope, Receive, Send + + +class ASGIMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + return await self.app(scope, receive, send) + + async def send_wrapper(message: Message) -> None: + # ...
Do something + await send(message) + + await self.app(scope, receive, send_wrapper) +``` + +Although this is easy, you may prefer stricter typing, in which case you'd need to use `asgiref`: + +```python +from asgiref.typing import ASGI3Application, Scope, ASGISendCallable +from asgiref.typing import ASGIReceiveCallable, ASGISendEvent + + +class ASGIMiddleware: + def __init__(self, app: ASGI3Application) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: + if scope["type"] == "http": + async def send_wrapper(message: ASGISendEvent) -> None: + await send(message) + return await self.app(scope, receive, send_wrapper) + await self.app(scope, receive, send) +``` + +!!! info + If you're curious about the `ASGI3Application` type in the snippet above, you can read more about ASGI versions in the [Legacy Applications section of the ASGI + documentation](https://asgi.readthedocs.io/en/latest/specs/main.html#legacy-applications). + +### Reusing Starlette components + +If you need to work with request or response data, you may find it more convenient to reuse Starlette data structures ([`Request`](requests.md#request), [`Headers`](requests.md#headers), [`QueryParams`](requests.md#query-parameters), [`URL`](requests.md#url), etc.) rather than work with raw ASGI data. All these components can be built from the ASGI `scope`, `receive` and `send`, allowing you to work on pure ASGI middleware at a higher level of abstraction. + +For example, we can create a `Request` object, and work with it. +```python +from starlette.requests import Request + +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope, receive, send) + # Do something here!
+ await self.app(scope, receive, send) +``` + +Or we might use `MutableHeaders` to change the response headers: + +```python +from starlette.datastructures import MutableHeaders + +class ExtraResponseHeadersMiddleware: + def __init__(self, app, headers): + self.app = app + self.headers = headers + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + return await self.app(scope, receive, send) + + async def wrapped_send(message): + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + for key, value in self.headers: + headers.append(key, value) + await send(message) + + await self.app(scope, receive, wrapped_send) +``` + +### Per-request state + +ASGI middleware classes should be stateless, as we typically don't want to leak state across requests. + +The risk is low when defining wrappers inside `__call__`, as state would typically be defined as inline variables. + +But if the middleware grows larger and more complex, you might be tempted to refactor wrappers as methods. Still, state should not be stored in the middleware instance. This means that the middleware should not have attributes that would change across requests or connections. If the middleware has an attribute, for example set in the `__init__()` method, nothing else should change it afterwards. Instead, if you need to manipulate per-request state, you may write a separate `Responder` class: + +```python +from functools import partial + +from starlette.datastructures import Headers + +class TweakMiddleware: + """ + Make a change to the response body if 'X-Tweak' is + present in the response headers. 
+ """ + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + return await self.app(scope, receive, send) + + responder = TweakResponder(self.app) + await responder(scope, receive, send) + +class TweakResponder: + def __init__(self, app): + self.app = app + self.should_tweak = False + + async def __call__(self, scope, receive, send): + send = partial(self.maybe_send_with_tweaks, send=send) + await self.app(scope, receive, send) + + async def maybe_send_with_tweaks(self, message, send): + if message["type"] == "http.response.start": + headers = Headers(raw=message["headers"]) + self.should_tweak = headers.get("X-Tweak") == "1" + await send(message) + return + + if message["type"] == "http.response.body": + if not self.should_tweak: + await send(message) + return + + # Actually tweak the response body... +``` + +See also [`GZipMiddleware`](https://github.com/encode/starlette/blob/9ef1b91c9c043197da6c3f38aa153fd874b95527/starlette/middleware/gzip.py) for a full example of this pattern. + +### Storing context in `scope` + +As we know by now, the `scope` holds the information about the connection. + +As per the ASGI specifications, any application can store custom information in the `scope`. +Keep in mind that other components could also store data in the same `scope`, so it's important to use a key that has a low chance of being used by other things. + +For example, if you are building an application called `super-app` you could have that as a prefix for any keys you put in the `scope`, and then you could have a key called `super-app-transaction-id`. + +```python +from uuid import uuid4 + + +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + scope["super-app-transaction-id"] = uuid4() + await self.app(scope, receive, send) +``` +In the example above, we stored a key called "super-app-transaction-id" in the scope.
That can be used by the application itself, as the scope is forwarded to it. + +!!! important + This documentation should be enough to give you a good basis for creating an ASGI middleware. + Nonetheless, there are great articles about the subject: + + - [Introduction to ASGI: Emergence of an Async Python Web Ecosystem](https://florimond.dev/en/posts/2019/08/introduction-to-asgi-async-python-web/) + - [How to write ASGI middleware](https://pgjones.dev/blog/how-to-write-asgi-middleware-2021/) + ## Using middleware in other frameworks To wrap ASGI middleware around other ASGI applications, you should use the diff --git a/docs/release-notes.md b/docs/release-notes.md index 1ac7205e0..c2d098eb8 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,3 +1,10 @@ +## 0.20.4 + +June 28, 2022 + +### Fixed +* Remove converter from path when generating OpenAPI schema [#1648](https://github.com/encode/starlette/pull/1648). + ## 0.20.3 June 10, 2022 diff --git a/docs/requests.md b/docs/requests.md index 872946638..a50c0753f 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -142,6 +142,13 @@ filename = form["upload_file"].filename contents = await form["upload_file"].read() ``` +!!! info + As specified in [RFC-7578: 4.2](https://www.ietf.org/rfc/rfc7578.txt), a form-data content part that contains a file + is assumed to have `name` and `filename` fields in its `Content-Disposition` header, for example + `Content-Disposition: form-data; name="user"; filename="somefile"`. Although the `filename` field is optional according to + RFC-7578, it helps Starlette decide which data should be treated as a file. If the `filename` field is supplied, an `UploadFile` + object is created to access the underlying file; otherwise the form-data part is parsed and made available as a raw + string.
#### Application diff --git a/requirements.txt b/requirements.txt index c834ac8b8..d1218aaf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,17 +8,17 @@ coverage==6.2 databases[sqlite]==0.5.5 flake8==3.9.2 isort==5.10.1 -mypy==0.960 +mypy==0.961 types-requests==2.26.3 -types-contextvars==2.4.6 +types-contextvars==2.4.7 types-PyYAML==6.0.4 -types-dataclasses==0.6.5 +types-dataclasses==0.6.6 pytest==7.1.2 -trio==0.19.0 +trio==0.21.0 # Documentation mkdocs==1.3.0 -mkdocs-material==8.2.8 +mkdocs-material==8.3.8 mkautodoc==0.1.0 # Packaging diff --git a/starlette/__init__.py b/starlette/__init__.py index 8815fb52f..8b8252f48 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.20.3" +__version__ = "0.20.4" diff --git a/starlette/requests.py b/starlette/requests.py index 66c510cfe..726abddcc 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,6 +1,5 @@ import json import typing -from collections.abc import Mapping from http import cookies as http_cookies import anyio @@ -60,7 +59,7 @@ class ClientDisconnect(Exception): pass -class HTTPConnection(Mapping): +class HTTPConnection(typing.Mapping[str, typing.Any]): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. 
@@ -143,7 +142,7 @@ def client(self) -> typing.Optional[Address]: return None @property - def session(self) -> dict: + def session(self) -> typing.Dict[str, typing.Any]: assert ( "session" in self.scope ), "SessionMiddleware must be installed to access request.session" @@ -231,7 +230,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: async def body(self) -> bytes: if not hasattr(self, "_body"): - chunks = [] + chunks: "typing.List[bytes]" = [] async for chunk in self.stream(): chunks.append(chunk) self._body = b"".join(chunks) @@ -249,7 +248,8 @@ async def form(self) -> FormData: parse_options_header is not None ), "The `python-multipart` library must be installed to use form parsing." content_type_header = self.headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) + content_type: bytes + content_type, _ = parse_options_header(content_type_header) if content_type == b"multipart/form-data": try: multipart_parser = MultiPartParser(self.headers, self.stream()) @@ -285,7 +285,7 @@ async def is_disconnected(self) -> bool: async def send_push_promise(self, path: str) -> None: if "http.response.push" in self.scope.get("extensions", {}): - raw_headers = [] + raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = [] for name in SERVER_PUSH_HEADERS_TO_COPY: for value in self.headers.getlist(name): raw_headers.append( diff --git a/starlette/responses.py b/starlette/responses.py index 8c67f9c9d..f93588b49 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -112,7 +112,7 @@ def set_cookie( httponly: bool = False, samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax", ) -> None: - cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie() + cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie() cookie[key] = value if max_age is not None: cookie[key]["max-age"] = max_age @@ -186,7 +186,7 @@ def __init__( self, content: typing.Any, status_code: int = 200, - headers: 
typing.Optional[dict] = None, + headers: typing.Optional[typing.Dict[str, str]] = None, media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> None: @@ -216,10 +216,18 @@ def __init__( self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") +Content = typing.Union[str, bytes] +SyncContentStream = typing.Iterator[Content] +AsyncContentStream = typing.AsyncIterable[Content] +ContentStream = typing.Union[AsyncContentStream, SyncContentStream] + + class StreamingResponse(Response): + body_iterator: AsyncContentStream + def __init__( self, - content: typing.Any, + content: ContentStream, status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, @@ -262,7 +270,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: + async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None: await func() task_group.cancel_scope.cancel() diff --git a/starlette/schemas.py b/starlette/schemas.py index 6ca764fdc..55bf7b397 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -1,4 +1,5 @@ import inspect +import re import typing from starlette.requests import Request @@ -49,10 +50,11 @@ def get_endpoints( for route in routes: if isinstance(route, Mount): + path = self._remove_converter(route.path) routes = route.routes or [] sub_endpoints = [ EndpointInfo( - path="".join((route.path, sub_endpoint.path)), + path="".join((path, sub_endpoint.path)), http_method=sub_endpoint.http_method, func=sub_endpoint.func, ) @@ -64,23 +66,32 @@ def get_endpoints( continue elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): + path = self._remove_converter(route.path) for method in route.methods or ["GET"]: if method == "HEAD": continue endpoints_info.append( - EndpointInfo(route.path, 
method.lower(), route.endpoint) + EndpointInfo(path, method.lower(), route.endpoint) ) else: + path = self._remove_converter(route.path) for method in ["get", "post", "put", "patch", "delete", "options"]: if not hasattr(route.endpoint, method): continue func = getattr(route.endpoint, method) - endpoints_info.append( - EndpointInfo(route.path, method.lower(), func) - ) + endpoints_info.append(EndpointInfo(path, method.lower(), func)) return endpoints_info + def _remove_converter(self, path: str) -> str: + """ + Remove the converter from the path. + For example, a route like this: + Route("/users/{id:int}", endpoint=get_user, methods=["GET"]) + Should be represented as `/users/{id}` in the OpenAPI schema. + """ + return re.sub(r":\w+}", "}", path) + def parse_docstring(self, func_or_method: typing.Callable) -> dict: """ Given a function, parse the docstring as YAML and return a dictionary of info. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index fa43785b9..26884b391 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -13,6 +13,17 @@ def ws(session): pass # pragma: no cover +def get_user(request): + """ + responses: + 200: + description: A user. + examples: + {"username": "tom"} + """ + pass # pragma: no cover + + def list_users(request): """ responses: @@ -103,6 +114,7 @@ def schema(request): app = Starlette( routes=[ WebSocketRoute("/ws", endpoint=ws), + Route("/users/{id:int}", endpoint=get_user, methods=["GET"]), Route("/users", endpoint=list_users, methods=["GET", "HEAD"]), Route("/users", endpoint=create_user, methods=["POST"]), Route("/orgs", endpoint=OrganisationsEndpoint), @@ -168,6 +180,16 @@ def test_schema_generation(): } }, }, + "/users/{id}": { + "get": { + "responses": { + 200: { + "description": "A user.", + "examples": {"username": "tom"}, + } + } + }, + }, }, } @@ -216,6 +238,13 @@ def test_schema_generation(): description: A user. examples: username: tom + /users/{id}: + get: + responses: + 200: + description: A user. 
+ examples: + username: tom """