Merge branch 'master' into task-group
adriangb committed Jul 2, 2022
2 parents c9ac227 + a3b43f0 commit f5cb1ee
Showing 9 changed files with 359 additions and 23 deletions.
278 changes: 276 additions & 2 deletions docs/middleware.md
@@ -226,12 +226,286 @@ Middleware classes should not modify their state outside of the `__init__` metho
Instead you should keep any state local to the `dispatch` method, or pass it
around explicitly, rather than mutating the middleware instance.

!!! bug
    Currently, the `BaseHTTPMiddleware` has some known issues:
!!! warning
    Currently, the `BaseHTTPMiddleware` has some known limitations:

    - It's not possible to use `BackgroundTasks` with `BaseHTTPMiddleware`. Check [#1438](https://github.com/encode/starlette/issues/1438) for more details.
    - Using `BaseHTTPMiddleware` will prevent changes to [`contextvars.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware, you will find that the value is not the one you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior).

## Pure ASGI Middleware

Thanks to how ASGI was designed, it is possible to implement middleware as a chain of ASGI applications, where each application calls into the next one.
Each element of the chain is an [`ASGI`](https://asgi.readthedocs.io/en/latest/) application in its own right, which, by definition, is also a middleware.

This is also an alternative approach in case the limitations of `BaseHTTPMiddleware` are a problem.

### Guiding principles

The most common way to create an ASGI middleware is with a class.

```python
class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        await self.app(scope, receive, send)
```

The middleware above is the most basic ASGI middleware. It receives an ASGI application as an argument to its constructor and implements an `async __call__` method. This method accepts `scope`, which contains information about the current connection, and `receive` and `send`, which let you exchange ASGI event messages with the ASGI server (learn more in the [ASGI specification](https://asgi.readthedocs.io/en/latest/specs/index.html)).

As an alternative to the class approach, you can also use a function:

```python
import functools

def asgi_middleware():
    def asgi_decorator(app):
        @functools.wraps(app)
        async def wrapped_app(scope, receive, send):
            await app(scope, receive, send)
        return wrapped_app
    return asgi_decorator
```

!!! info
    The function pattern is not widely used, but you can find a more advanced implementation of it in
    [asgi-cors](https://github.com/simonw/asgi-cors/blob/10ef64bfcc6cd8d16f3014077f20a0fb8544ec39/asgi_cors.py).
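
To make the function pattern concrete, here is a minimal sketch of how such a factory could be applied. `hello_app` and the dummy `receive`/`send` callables are hypothetical stand-ins for a real application and ASGI server, used only so the snippet runs on its own.

```python
import asyncio
import functools


def asgi_middleware():
    def asgi_decorator(app):
        @functools.wraps(app)
        async def wrapped_app(scope, receive, send):
            await app(scope, receive, send)

        return wrapped_app

    return asgi_decorator


# Hypothetical application, used only for illustration.
async def hello_app(scope, receive, send):
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"Hello"})


# Calling the factory returns the decorator, which then wraps the app.
app = asgi_middleware()(hello_app)


async def demo():
    events = []

    # Stand-ins for the callables an ASGI server would provide.
    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        events.append(message)

    await app({"type": "http"}, receive, send)
    return events


events = asyncio.run(demo())
```

The extra level of nesting only pays off when the outer function takes configuration arguments; for a middleware without options, the decorator alone would be enough.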

#### `Scope` types

As we mentioned, the scope holds the information about the connection. There are three types of `scope`s:

- [`lifespan`](https://asgi.readthedocs.io/en/latest/specs/lifespan.html#scope) is a special type of scope that is used for the lifespan of the ASGI application.
- [`http`](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope) is a type of scope that is used for HTTP requests.
- [`websocket`](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope) is a type of scope that is used for WebSocket connections.

If you want to create a middleware that only runs on HTTP requests, you'd write something like:

```python
class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # Do something here!
        await self.app(scope, receive, send)
```

In the example above, if the `scope` type is **not** `http`, meaning it is either `lifespan` or `websocket`, we call `self.app` directly.

The same applies to other scopes.
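
As a sketch of the same guard applied to another scope type, the middleware below only acts on `websocket` scopes and passes everything else straight through. The `on_websocket` hook, `noop_app` and the dummy `receive`/`send` callables are hypothetical names introduced for illustration.

```python
import asyncio


class WebSocketOnlyMiddleware:
    def __init__(self, app, on_websocket):
        self.app = app
        # Hypothetical hook, set once at construction and never mutated.
        self.on_websocket = on_websocket

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            return await self.app(scope, receive, send)

        self.on_websocket(scope)
        await self.app(scope, receive, send)


# Hypothetical application that does nothing.
async def noop_app(scope, receive, send):
    pass


seen_paths = []
app = WebSocketOnlyMiddleware(
    noop_app, on_websocket=lambda scope: seen_paths.append(scope["path"])
)


async def demo():
    async def receive():
        return {}

    async def send(message):
        pass

    await app({"type": "lifespan"}, receive, send)
    await app({"type": "http", "path": "/"}, receive, send)
    await app({"type": "websocket", "path": "/ws"}, receive, send)


asyncio.run(demo())
```

Only the `websocket` call reaches the hook, so `seen_paths` ends up containing just `"/ws"`.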

!!! tip
    Middleware classes should be stateless -- see [Per-request state](#per-request-state) if you do need to store per-request state.

#### Wrapping `send` and `receive`

A common pattern that you'll probably need is wrapping the `send` or `receive` callables.

For example, here's how we could write a middleware that logs the response status code, which we'd obtain
by wrapping the `send` with the `send_wrapper` callable:

```python
class LogStatusCodeMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        status_code = 500

        async def send_wrapper(message):
            # Rebind the variable of the enclosing `__call__`; without
            # `nonlocal`, the assignment would create a new local variable.
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
            await send(message)

        await self.app(scope, receive, send_wrapper)

        print("This is a primitive access log")
        print(f"status = {status_code}")
```

!!! info
    You can find a more advanced implementation of the same idea in [asgi-logger](https://github.com/Kludex/asgi-logger/blob/main/asgi_logger/middleware.py).
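
The same approach works for `receive`. The following is a minimal sketch, not taken from Starlette, of a middleware that adds up the size of the request body as the application reads it. `LogRequestBodySizeMiddleware`, `read_body_app` and the fake server callables in `demo` are hypothetical names used for illustration.

```python
import asyncio


class LogRequestBodySizeMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        body_size = 0

        async def receive_wrapper():
            nonlocal body_size
            message = await receive()
            # Only "http.request" messages carry body chunks.
            if message["type"] == "http.request":
                body_size += len(message.get("body", b""))
            return message

        await self.app(scope, receive_wrapper, send)

        print(f"request body size = {body_size}")


# Hypothetical application that reads the whole body and echoes it back.
async def read_body_app(scope, receive, send):
    body = b""
    more_body = True
    while more_body:
        message = await receive()
        body += message.get("body", b"")
        more_body = message.get("more_body", False)
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": body})


async def demo():
    # Fake server side: the request body arrives in two chunks.
    chunks = [
        {"type": "http.request", "body": b"hello ", "more_body": True},
        {"type": "http.request", "body": b"world", "more_body": False},
    ]
    sent = []

    async def receive():
        return chunks.pop(0)

    async def send(message):
        sent.append(message)

    app = LogRequestBodySizeMiddleware(read_body_app)
    await app({"type": "http"}, receive, send)
    return sent[-1]["body"]


received_body = asyncio.run(demo())
```

Note that the wrapper only counts what the application actually reads; if the application never calls `receive`, the logged size stays at zero.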

#### Type annotations

There are two ways of annotating a middleware: using Starlette itself or [`asgiref`](https://github.com/django/asgiref).

Using Starlette, you can write:

```python
from starlette.types import ASGIApp, Message, Scope, Receive, Send


class ASGIMiddleware:
    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def send_wrapper(message: Message) -> None:
            # ... Do something
            await send(message)

        await self.app(scope, receive, send_wrapper)
```

Although this is easy, you may prefer stricter types, in which case you'd need to use `asgiref`:

```python
from asgiref.typing import ASGI3Application, ASGIReceiveCallable, ASGISendCallable
from asgiref.typing import ASGISendEvent, Scope


class ASGIMiddleware:
    def __init__(self, app: ASGI3Application) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
        if scope["type"] == "http":
            async def send_wrapper(message: ASGISendEvent) -> None:
                await send(message)
            return await self.app(scope, receive, send_wrapper)
        await self.app(scope, receive, send)
```

!!! info
    If you're curious about the `ASGI3Application` type in the snippet above, you can read more about ASGI versions in the [Legacy Applications section of the ASGI
    documentation](https://asgi.readthedocs.io/en/latest/specs/main.html#legacy-applications).

### Reusing Starlette components

If you need to work with request or response data, you may find it more convenient to reuse Starlette data structures ([`Request`](requests.md#request), [`Headers`](requests.md#headers), [`QueryParams`](requests.md#query-parameters), [`URL`](requests.md#url), etc.) rather than work with raw ASGI data. All these components can be built from the ASGI `scope`, `receive` and `send`, allowing you to write pure ASGI middleware at a higher level of abstraction.

For example, we can create a `Request` object, and work with it.

```python
from starlette.requests import Request

class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            request = Request(scope, receive, send)
            # Do something here!
        await self.app(scope, receive, send)
```

Or we might use `MutableHeaders` to change the response headers:

```python
from starlette.datastructures import MutableHeaders

class ExtraResponseHeadersMiddleware:
    def __init__(self, app, headers):
        self.app = app
        self.headers = headers

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def wrapped_send(message):
            if message["type"] == "http.response.start":
                headers = MutableHeaders(scope=message)
                for key, value in self.headers:
                    headers.append(key, value)
            await send(message)

        await self.app(scope, receive, wrapped_send)
```

### Per-request state

ASGI middleware classes should be stateless, as we typically don't want to leak state across requests.

The risk is low when defining wrappers inside `__call__`, as state would typically be defined as inline variables.

But if the middleware grows larger and more complex, you might be tempted to refactor wrappers as methods. Still, state should not be stored in the middleware instance. This means that the middleware should not have attributes that would change across requests or connections. If the middleware has an attribute, for example set in the `__init__()` method, nothing else should change it afterwards. Instead, if you need to manipulate per-request state, you may write a separate `Responder` class:

```python
from functools import partial

from starlette.datastructures import Headers

class TweakMiddleware:
    """
    Make a change to the response body if 'X-Tweak' is
    present in the response headers.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # A new responder per request keeps state off the middleware instance.
        responder = TweakResponder(self.app)
        await responder(scope, receive, send)

class TweakResponder:
    def __init__(self, app):
        self.app = app
        self.should_tweak = False

    async def __call__(self, scope, receive, send):
        send = partial(self.maybe_send_with_tweaks, send=send)
        await self.app(scope, receive, send)

    async def maybe_send_with_tweaks(self, message, send):
        if message["type"] == "http.response.start":
            headers = Headers(raw=message["headers"])
            self.should_tweak = headers.get("X-Tweak") == "1"
            await send(message)
            return

        if message["type"] == "http.response.body":
            if not self.should_tweak:
                await send(message)
                return

            # Actually tweak the response body...
```

See also [`GZipMiddleware`](https://github.com/encode/starlette/blob/9ef1b91c9c043197da6c3f38aa153fd874b95527/starlette/middleware/gzip.py) for a full example of this pattern.
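
For a version of the responder pattern that runs end to end, here is a minimal sketch that needs no Starlette imports. The concrete tweak (uppercasing the body), the class names `UppercaseMiddleware` and `UppercaseResponder`, and the helpers `make_app` and `call` are assumptions made for illustration; they are not part of Starlette.

```python
import asyncio
from functools import partial


class UppercaseMiddleware:
    """Hypothetical example: uppercase the body when 'X-Tweak: 1' is set."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # A fresh responder per request keeps state off the middleware.
        responder = UppercaseResponder(self.app)
        await responder(scope, receive, send)


class UppercaseResponder:
    def __init__(self, app):
        self.app = app
        self.should_tweak = False

    async def __call__(self, scope, receive, send):
        send = partial(self.maybe_send_with_tweaks, send=send)
        await self.app(scope, receive, send)

    async def maybe_send_with_tweaks(self, message, send):
        if message["type"] == "http.response.start":
            # Raw ASGI headers are (name, value) byte pairs, names lowercased.
            headers = dict(message["headers"])
            self.should_tweak = headers.get(b"x-tweak") == b"1"
        elif message["type"] == "http.response.body" and self.should_tweak:
            # Uppercasing keeps the byte length, so Content-Length stays valid.
            message = {**message, "body": message.get("body", b"").upper()}
        await send(message)


# Hypothetical application factory: responds with the given raw headers.
def make_app(headers):
    async def app(scope, receive, send):
        await send({"type": "http.response.start", "status": 200, "headers": headers})
        await send({"type": "http.response.body", "body": b"hello"})

    return app


async def call(app):
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    await UppercaseMiddleware(app)({"type": "http"}, receive, send)
    return sent[-1]["body"]


tweaked = asyncio.run(call(make_app([(b"x-tweak", b"1")])))
untouched = asyncio.run(call(make_app([])))
```

A tweak that changes the body length would also have to rewrite or drop the `Content-Length` header, which means delaying the `http.response.start` message until the new length is known.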

### Storing context in `scope`

As we know by now, the `scope` holds the information about the connection.

As per the ASGI specifications, any application can store custom information on the `scope`.
Keep in mind that other components could also store data in the same `scope`, so it's important to use a key that is unlikely to be used by anything else.

For example, if you are building an application called `super-app` you could have that as a prefix for any keys you put in the `scope`, and then you could have a key called `super-app-transaction-id`.

```python
from uuid import uuid4


class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        scope["super-app-transaction-id"] = uuid4()
        await self.app(scope, receive, send)
```

In the example above, we stored a key called "super-app-transaction-id" in the scope. The application itself can use it, as the scope is forwarded to it.
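
To illustrate the consuming side, here is a minimal sketch in which a downstream application reads the key back from the `scope`. The `app` function and the dummy `receive`/`send` callables are hypothetical stand-ins for a real application and server.

```python
import asyncio
from uuid import UUID, uuid4


class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        scope["super-app-transaction-id"] = uuid4()
        await self.app(scope, receive, send)


# Hypothetical downstream application that echoes the transaction ID.
async def app(scope, receive, send):
    transaction_id = scope["super-app-transaction-id"]
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": str(transaction_id).encode()})


async def demo():
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http"}
    await ASGIMiddleware(app)(scope, receive, send)
    return scope, sent[-1]["body"]


scope, body = asyncio.run(demo())
```

Because the middleware mutates the same `scope` dictionary that it forwards, the value is visible to every application further down the chain.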

!!! important
    This documentation should give you a good basis for creating ASGI middleware.
    Nonetheless, there are great articles about the subject:

    - [Introduction to ASGI: Emergence of an Async Python Web Ecosystem](https://florimond.dev/en/posts/2019/08/introduction-to-asgi-async-python-web/)
    - [How to write ASGI middleware](https://pgjones.dev/blog/how-to-write-asgi-middleware-2021/)

## Using middleware in other frameworks

To wrap ASGI middleware around other ASGI applications, you should use the
7 changes: 7 additions & 0 deletions docs/release-notes.md
@@ -1,3 +1,10 @@
## 0.20.4

June 28, 2022

### Fixed
* Remove converter from path when generating OpenAPI schema [#1648](https://github.com/encode/starlette/pull/1648).

## 0.20.3

June 10, 2022
7 changes: 7 additions & 0 deletions docs/requests.md
@@ -142,6 +142,13 @@ filename = form["upload_file"].filename
contents = await form["upload_file"].read()
```

!!! info
    As specified in [RFC-7578: 4.2](https://www.ietf.org/rfc/rfc7578.txt), a form-data content part that contains a file
    is assumed to have `name` and `filename` fields in its `Content-Disposition` header:
    `Content-Disposition: form-data; name="user"; filename="somefile"`. Though the `filename` field is optional according
    to RFC-7578, it helps Starlette differentiate which data should be treated as a file. If the `filename` field is
    supplied, an `UploadFile` object is created to access the underlying file; otherwise the form-data part is parsed and
    made available as a raw string.
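
The rule above can be illustrated with a small standard-library sketch. This is not Starlette's parser: `part_kind` is a hypothetical helper that only shows how the presence of a `filename` parameter separates file parts from plain fields.

```python
from email.message import Message


def part_kind(content_disposition: str) -> str:
    """Classify a form-data part as "file" or "field" by its filename parameter."""
    header = Message()
    header["Content-Disposition"] = content_disposition
    # get_param returns None when the parameter is absent.
    filename = header.get_param("filename", header="content-disposition")
    return "file" if filename is not None else "field"


file_kind = part_kind('form-data; name="user"; filename="somefile"')
field_kind = part_kind('form-data; name="user"')
```

Here the first header is classified as a file part and the second as a plain field, mirroring the `UploadFile` versus raw string distinction described in the note.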

#### Application

10 changes: 5 additions & 5 deletions requirements.txt
@@ -8,17 +8,17 @@ coverage==6.2
databases[sqlite]==0.5.5
flake8==3.9.2
isort==5.10.1
mypy==0.960
mypy==0.961
types-requests==2.26.3
types-contextvars==2.4.6
types-contextvars==2.4.7
types-PyYAML==6.0.4
types-dataclasses==0.6.5
types-dataclasses==0.6.6
pytest==7.1.2
trio==0.19.0
trio==0.21.0

# Documentation
mkdocs==1.3.0
mkdocs-material==8.2.8
mkdocs-material==8.3.8
mkautodoc==0.1.0

# Packaging
2 changes: 1 addition & 1 deletion starlette/__init__.py
@@ -1 +1 @@
__version__ = "0.20.3"
__version__ = "0.20.4"
12 changes: 6 additions & 6 deletions starlette/requests.py
@@ -1,6 +1,5 @@
import json
import typing
from collections.abc import Mapping
from http import cookies as http_cookies

import anyio
@@ -60,7 +59,7 @@ class ClientDisconnect(Exception):
pass


class HTTPConnection(Mapping):
class HTTPConnection(typing.Mapping[str, typing.Any]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
@@ -143,7 +142,7 @@ def client(self) -> typing.Optional[Address]:
return None

@property
def session(self) -> dict:
def session(self) -> typing.Dict[str, typing.Any]:
assert (
"session" in self.scope
), "SessionMiddleware must be installed to access request.session"
@@ -231,7 +230,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:

async def body(self) -> bytes:
if not hasattr(self, "_body"):
chunks = []
chunks: "typing.List[bytes]" = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
@@ -249,7 +248,8 @@ async def form(self) -> FormData:
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
content_type_header = self.headers.get("Content-Type")
content_type, options = parse_options_header(content_type_header)
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
try:
multipart_parser = MultiPartParser(self.headers, self.stream())
@@ -285,7 +285,7 @@ async def is_disconnected(self) -> bool:

async def send_push_promise(self, path: str) -> None:
if "http.response.push" in self.scope.get("extensions", {}):
raw_headers = []
raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
raw_headers.append(