diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 55dbe8564..b9038ca11 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -4,3 +4,7 @@ updates: directory: "/" schedule: interval: "monthly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: monthly diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 45debc642..010cb3ab8 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,4 +1,4 @@ -The starting point for contributions should usually be [a discussion](https://github.com/encode/httpx/discussions) +The starting point for contributions should usually be [a discussion](https://github.com/encode/starlette/discussions) Simple documentation typos may be raised as stand-alone pull requests, but otherwise please ensure you've discussed your proposal prior to issuing a pull request. diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 514ed4b97..85350472c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -14,8 +14,8 @@ jobs: name: deploy steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v2" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" with: python-version: 3.7 - name: "Install dependencies" diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 86b4a03d4..ccd112436 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,11 +14,11 @@ jobs: strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11-dev"] steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v2" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" with: python-version: "${{ matrix.python-version }}" - name: "Install dependencies" diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 9cccc91b7..000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -include LICENSE.md -global-exclude __pycache__ -global-exclude *.py[co] diff --git a/README.md b/README.md index c5940b440..afc3a82bd 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,11 @@ It is production-ready, and gives you the following: * 100% type annotated codebase. * Few hard dependencies. * Compatible with `asyncio` and `trio` backends. -* Great overall performance [against independant benchmarks][techempower]. +* Great overall performance [against independent benchmarks][techempower]. ## Requirements -Python 3.6+ +Python 3.7+ (For Python 3.6 support, install version 0.19.1) ## Installation diff --git a/docs/authentication.md b/docs/authentication.md index d6cec3fb2..6fe71d4e1 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -115,6 +115,10 @@ async def dashboard(request): ... ``` +!!! note + The `status_code` parameter is not supported with WebSockets. The 403 (Forbidden) + status code will always be used for those. + Alternatively you might want to redirect unauthenticated users to a different page. @@ -174,6 +178,8 @@ You can customise the error response sent when a `AuthenticationError` is raised by an auth backend: ```python +from starlette.applications import Starlette +from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import JSONResponse @@ -182,5 +188,9 @@ from starlette.responses import JSONResponse def on_auth_error(request: Request, exc: Exception): return JSONResponse({"error": str(exc)}, status_code=401) -app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error) +app = Starlette( + middleware=[ + Middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error), + ], +) ``` diff --git a/docs/background.md b/docs/background.md index e10832a92..a6bfd8c5f 100644 --- a/docs/background.md +++ b/docs/background.md @@ -72,3 +72,7 @@ routes = [ app = Starlette(routes=routes) ``` + +!!! important + The tasks are executed in order. In case one of the tasks raises + an exception, the following tasks will not get the opportunity to be executed. diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 000000000..fcb88e5f1 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,170 @@ +# Contributing + +Thank you for being interested in contributing to Starlette. +There are many ways you can contribute to the project: + +- Try Starlette and [report bugs/issues you find](https://github.com/encode/starlette/issues/new) +- [Implement new features](https://github.com/encode/starlette/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) +- [Review Pull Requests of others](https://github.com/encode/starlette/pulls) +- Write documentation +- Participate in discussions + +## Reporting Bugs or Other Issues + +Found something that Starlette should support? +Stumbled upon some unexpected behaviour? + +Contributions should generally start out with [a discussion](https://github.com/encode/starlette/discussions). +Possible bugs may be raised as a "Potential Issue" discussion, feature requests may +be raised as an "Ideas" discussion. We can then determine if the discussion needs +to be escalated into an "Issue" or not, or if we'd consider a pull request. + +Try to be more descriptive as you can and in case of a bug report, +provide as much information as possible like: + +- OS platform +- Python version +- Installed dependencies and versions (`python -m pip freeze`) +- Code snippet +- Error traceback + +You should always try to reduce any examples to the *simplest possible case* +that demonstrates the issue. + +## Development + +To start developing Starlette, create a **fork** of the +[Starlette repository](https://github.com/encode/starlette) on GitHub. + +Then clone your fork with the following command replacing `YOUR-USERNAME` with +your GitHub username: + +```shell +$ git clone https://github.com/YOUR-USERNAME/starlette +``` + +You can now install the project and its dependencies using: + +```shell +$ cd starlette +$ scripts/install +``` + +## Testing and Linting + +We use custom shell scripts to automate testing, linting, +and documentation building workflow. + +To run the tests, use: + +```shell +$ scripts/test +``` + +Any additional arguments will be passed to `pytest`. See the [pytest documentation](https://docs.pytest.org/en/latest/how-to/usage.html) for more information. + +For example, to run a single test script: + +```shell +$ scripts/test tests/test_application.py +``` + +To run the code auto-formatting: + +```shell +$ scripts/lint +``` + +Lastly, to run code checks separately (they are also run as part of `scripts/test`), run: + +```shell +$ scripts/check +``` + +## Documenting + +Documentation pages are located under the `docs/` folder. + +To run the documentation site locally (useful for previewing changes), use: + +```shell +$ scripts/docs +``` + +## Resolving Build / CI Failures + +Once you've submitted your pull request, the test suite will automatically run, and the results will show up in GitHub. +If the test suite fails, you'll want to click through to the "Details" link, and try to identify why the test suite failed. + +

+ Failing PR commit status +

+ +Here are some common ways the test suite can fail: + +### Check Job Failed + +

+ Failing GitHub action lint job +

+ +This job failing means there is either a code formatting issue or type-annotation issue. +You can look at the job output to figure out why it's failed or within a shell run: + +```shell +$ scripts/check +``` + +It may be worth it to run `$ scripts/lint` to attempt auto-formatting the code +and if that job succeeds commit the changes. + +### Docs Job Failed + +This job failing means the documentation failed to build. This can happen for +a variety of reasons like invalid markdown or missing configuration within `mkdocs.yml`. + +### Python 3.X Job Failed + +

+ Failing GitHub action test job +

+ +This job failing means the unit tests failed or not all code paths are covered by unit tests. + +If tests are failing you will see this message under the coverage report: + +`=== 1 failed, 435 passed, 1 skipped, 1 xfailed in 11.09s ===` + +If tests succeed but coverage doesn't reach our current threshold, you will see this +message under the coverage report: + +`FAIL Required test coverage of 100% not reached. Total coverage: 99.00%` + +## Releasing + +*This section is targeted at Starlette maintainers.* + +Before releasing a new version, create a pull request that includes: + +- **An update to the changelog**: + - We follow the format from [keepachangelog](https://keepachangelog.com/en/1.0.0/). + - [Compare](https://github.com/encode/starlette/compare/) `master` with the tag of the latest release, and list all entries that are of interest to our users: + - Things that **must** go in the changelog: added, changed, deprecated or removed features, and bug fixes. + - Things that **should not** go in the changelog: changes to documentation, tests or tooling. + - Try sorting entries in descending order of impact / importance. + - Keep it concise and to-the-point. 🎯 +- **A version bump**: see `__version__.py`. + +For an example, see [#1600](https://github.com/encode/starlette/pull/1600). + +Once the release PR is merged, create a +[new release](https://github.com/encode/starlette/releases/new) including: + +- Tag version like `0.13.3`. +- Release title `Version 0.13.3` +- Description copied from the changelog. + +Once created this release will be automatically uploaded to PyPI. + +If something goes wrong with the PyPI job the release can be published using the +`scripts/publish` script. diff --git a/docs/img/gh-actions-fail-check.png b/docs/img/gh-actions-fail-check.png new file mode 100644 index 000000000..a1e69a661 Binary files /dev/null and b/docs/img/gh-actions-fail-check.png differ diff --git a/docs/img/gh-actions-fail-test.png b/docs/img/gh-actions-fail-test.png new file mode 100644 index 000000000..b02d3c062 Binary files /dev/null and b/docs/img/gh-actions-fail-test.png differ diff --git a/docs/img/gh-actions-fail.png b/docs/img/gh-actions-fail.png new file mode 100644 index 000000000..901d72b0e Binary files /dev/null and b/docs/img/gh-actions-fail.png differ diff --git a/docs/index.md b/docs/index.md index b618fe3ea..9f1977875 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,11 +34,11 @@ It is production-ready, and gives you the following: * 100% type annotated codebase. * Few hard dependencies. * Compatible with `asyncio` and `trio` backends. -* Great overall performance [against independant benchmarks][techempower]. +* Great overall performance [against independent benchmarks][techempower]. ## Requirements -Python 3.6+ +Python 3.7+ (For Python 3.6 support, install version 0.19.1) ## Installation diff --git a/docs/middleware.md b/docs/middleware.md index b21914291..6b8ccf902 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -3,6 +3,8 @@ Starlette includes several middleware classes for adding behavior that is applie your entire application. These are all implemented as standard ASGI middleware classes, and can be applied either to Starlette or to any other ASGI application. +## Using middleware + The Starlette application class allows you to include the ASGI middleware in a way that ensures that it remains wrapped by the exception handler. @@ -14,11 +16,14 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware routes = ... -# Ensure that all requests include an 'example.com' or '*.example.com' host header, -# and strictly enforce https-only access. +# Ensure that all requests include an 'example.com' or +# '*.example.com' host header, and strictly enforce https-only access. middleware = [ - Middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com']), - Middleware(HTTPSRedirectMiddleware) + Middleware( + TrustedHostMiddleware, + allowed_hosts=['example.com', '*.example.com'], + ), + Middleware(HTTPSRedirectMiddleware) ] app = Starlette(routes=routes, middleware=middleware) @@ -177,7 +182,9 @@ The following arguments are supported: ## BaseHTTPMiddleware An abstract class that allows you to write ASGI middleware against a request/response -interface, rather than dealing with ASGI messages directly. +interface. + +### Usage To implement a middleware class using `BaseHTTPMiddleware`, you must override the `async def dispatch(request, call_next)` method. @@ -191,6 +198,8 @@ class CustomHeaderMiddleware(BaseHTTPMiddleware): response.headers['Custom'] = 'Example' return response +routes = ... + middleware = [ Middleware(CustomHeaderMiddleware) ] @@ -215,7 +224,6 @@ class CustomHeaderMiddleware(BaseHTTPMiddleware): return response - middleware = [ Middleware(CustomHeaderMiddleware, header_value='Customized') ] @@ -227,6 +235,439 @@ Middleware classes should not modify their state outside of the `__init__` metho Instead you should keep any state local to the `dispatch` method, or pass it around explicitly, rather than mutating the middleware instance. +### Limitations + +Currently, the `BaseHTTPMiddleware` has some known limitations: + +- It's not possible to use `BackgroundTasks` with `BaseHTTPMiddleware`. Check [#1438](https://github.com/encode/starlette/issues/1438) for more details. +- Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). + +To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below. + +## Pure ASGI Middleware + +The [ASGI spec](https://asgi.readthedocs.io/en/latest/) makes it possible to implement ASGI middleware using the ASGI interface directly, as a chain of ASGI applications that call into the next one. In fact, this is how middleware classes shipped with Starlette are implemented. + +This lower-level approach provides greater control over behavior and enhanced interoperability across frameworks and servers. It also overcomes the [limitations of `BaseHTTPMiddleware`](#limitations). + +### Writing pure ASGI middleware + +The most common way to create an ASGI middleware is with a class. + +```python +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + await self.app(scope, receive, send) +``` + +The middleware above is the most basic ASGI middleware. It receives a parent ASGI application as an argument for its constructor, and implements an `async __call__` method which calls into that parent application. + +Some implementations such as [`asgi-cors`](https://github.com/simonw/asgi-cors/blob/10ef64bfcc6cd8d16f3014077f20a0fb8544ec39/asgi_cors.py) use an alternative style, using functions: + +```python +import functools + +def asgi_middleware(): + def asgi_decorator(app): + + @functools.wraps(app) + async def wrapped_app(scope, receive, send): + await app(scope, receive, send) + + return wrapped_app + + return asgi_decorator +``` + +In any case, ASGI middleware must be callables that accept three arguments: `scope`, `receive`, and `send`. + +* `scope` is a dict holding information about the connection, where `scope["type"]` may be: + * [`"http"`](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope): for HTTP requests. + * [`"websocket"`](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope): for WebSocket connections. + * [`"lifespan"`](https://asgi.readthedocs.io/en/latest/specs/lifespan.html#scope): for ASGI lifespan messages. +* `receive` and `send` can be used to exchange ASGI event messages with the ASGI server — more on this below. The type and contents of these messages depend on the scope type. Learn more in the [ASGI specification](https://asgi.readthedocs.io/en/latest/specs/index.html). + +### Using pure ASGI middleware + +Pure ASGI middleware can be used like any other middleware: + +```python +from starlette.applications import Starlette +from starlette.middleware import Middleware + +from .middleware import ASGIMiddleware + +routes = ... + +middleware = [ + Middleware(ASGIMiddleware), +] + +app = Starlette(..., middleware=middleware) +``` + +See also [Using middleware](#using-middleware). + +### Type annotations + +There are two ways of annotating a middleware: using Starlette itself or [`asgiref`](https://github.com/django/asgiref). + +* Using Starlette: for most common use cases. + +```python +from starlette.types import ASGIApp, Message, Scope, Receive, Send + + +class ASGIMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + return await self.app(scope, receive, send) + + async def send_wrapper(message: Message) -> None: + # ... Do something + await send(message) + + await self.app(scope, receive, send_wrapper) +``` + +* Using [`asgiref`](https://github.com/django/asgiref): for more rigorous type hinting. + +```python +from asgiref.typing import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope +from asgiref.typing import ASGIReceiveEvent, ASGISendEvent + + +class ASGIMiddleware: + def __init__(self, app: ASGI3Application) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + async def send_wrapper(message: ASGISendEvent) -> None: + # ... Do something + await send(message) + + return await self.app(scope, receive, send_wrapper) +``` + +### Common patterns + +#### Processing certain requests only + +ASGI middleware can apply specific behavior according to the contents of `scope`. + +For example, to only process HTTP requests, write this... + +```python +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + ... # Do something here! + + await self.app(scope, receive, send) +``` + +Likewise, WebSocket-only middleware would guard on `scope["type"] != "websocket"`. + +The middleware may also act differently based on the request method, URL, headers, etc. + +#### Reusing Starlette components + +Starlette provides several data structures that accept the ASGI `scope`, `receive` and/or `send` arguments, allowing you to work at a higher level of abstraction. Such data structures include [`Request`](requests.md#request), [`Headers`](requests.md#headers), [`QueryParams`](requests.md#query-parameters), [`URL`](requests.md#url), etc. + +For example, you can instantiate a `Request` to more easily inspect an HTTP request: + +```python +from starlette.requests import Request + +class ASGIMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + ... # Use `request.method`, `request.url`, `request.headers`, etc. + + await self.app(scope, receive, send) +``` + +You can also reuse [responses](responses.md), which are ASGI applications as well. + +#### Sending eager responses + +Inspecting the connection `scope` allows you to conditionally call into a different ASGI app. One use case might be sending a response without calling into the app. + +As an example, this middleware uses a dictionary to perform permanent redirects based on the requested path. This could be used to implement ongoing support of legacy URLs in case you need to refactor route URL patterns. + +```python +from starlette.datastructures import URL +from starlette.responses import RedirectResponse + +class RedirectsMiddleware: + def __init__(self, app, path_mapping: dict): + self.app = app + self.path_mapping = path_mapping + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + url = URL(scope=scope) + + if url.path in self.path_mapping: + url = url.replace(path=self.path_mapping[url.path]) + response = RedirectResponse(url, status_code=301) + await response(scope, receive, send) + return + + await self.app(scope, receive, send) +``` + +Example usage would look like this: + +```python +from starlette.applications import Starlette +from starlette.middleware import Middleware + +routes = ... + +redirections = { + "/v1/resource/": "/v2/resource/", + # ... +} + +middleware = [ + Middleware(RedirectsMiddleware, path_mapping=redirections), +] + +app = Starlette(routes=routes, middleware=middleware) +``` + + +#### Inspecting or modifying the request + +Request information can be accessed or changed by manipulating the `scope`. For a full example of this pattern, see Uvicorn's [`ProxyHeadersMiddleware`](https://github.com/encode/uvicorn/blob/fd4386fefb8fe8a4568831a7d8b2930d5fb61455/uvicorn/middleware/proxy_headers.py) which inspects and tweaks the `scope` when serving behind a frontend proxy. + +Besides, wrapping the `receive` ASGI callable allows you to access or modify the HTTP request body by manipulating [`http.request`](https://asgi.readthedocs.io/en/latest/specs/www.html#request-receive-event) ASGI event messages. + +As an example, this middleware computes and logs the size of the incoming request body... + +```python +class LoggedRequestBodySizeMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + body_size = 0 + + async def receive_logging_request_body_size(): + nonlocal body_size + + message = await receive() + assert message["type"] == "http.request" + + body_size += len(message.get("body", b"")) + + if not message.get("more_body", False): + print(f"Size of request body was: {body_size} bytes") + + return message + + await self.app(scope, receive_logging_request_body_size, send) +``` + +Likewise, WebSocket middleware may manipulate [`websocket.receive`](https://asgi.readthedocs.io/en/latest/specs/www.html#receive-receive-event) ASGI event messages to inspect or alter incoming WebSocket data. + +For an example that changes the HTTP request body, see [`msgpack-asgi`](https://github.com/florimondmanca/msgpack-asgi). + +#### Inspecting or modifying the response + +Wrapping the `send` ASGI callable allows you to inspect or modify the HTTP response sent by the underlying application. To do so, react to [`http.response.start`](https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event) or [`http.response.body`](https://asgi.readthedocs.io/en/latest/specs/www.html#response-body-send-event) ASGI event messages. + +As an example, this middleware adds some fixed extra response headers: + +```python +from starlette.datastructures import MutableHeaders + +class ExtraResponseHeadersMiddleware: + def __init__(self, app, headers): + self.app = app + self.headers = headers + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + return await self.app(scope, receive, send) + + async def send_with_extra_headers(message): + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + for key, value in self.headers: + headers.append(key, value) + + await send(message) + + await self.app(scope, receive, send_with_extra_headers) +``` + +See also [`asgi-logger`](https://github.com/Kludex/asgi-logger/blob/main/asgi_logger/middleware.py) for an example that inspects the HTTP response and logs a configurable HTTP access log line. + +Likewise, WebSocket middleware may manipulate [`websocket.send`](https://asgi.readthedocs.io/en/latest/specs/www.html#send-send-event) ASGI event messages to inspect or alter outgoing WebSocket data. + +Note that if you change the response body, you will need to update the response `Content-Length` header to match the new response body length. See [`brotli-asgi`](https://github.com/fullonic/brotli-asgi) for a complete example. + +#### Passing information to endpoints + +If you need to share information with the underlying app or endpoints, you may store it into the `scope` dictionary. Note that this is a convention -- for example, Starlette uses this to share routing information with endpoints -- but it is not part of the ASGI specification. If you do so, be sure to avoid conflicts by using keys that have low chances of being used by other middleware or applications. + +For example, when including the middleware below, endpoints would be able to access `request.scope["asgi_transaction_id"]`. + +```python +import uuid + +class TransactionIDMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + scope["asgi_transaction_id"] = uuid.uuid4() + await self.app(scope, receive, send) +``` + +#### Cleanup and error handling + +You can wrap the application in a `try/except/finally` block or a context manager to perform cleanup operations or do error handling. + +For example, the following middleware might collect metrics and process application exceptions... + +```python +import time + +class MonitoringMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + start = time.time() + try: + await self.app(scope, receive, send) + except Exception as exc: + ... # Process the exception + raise + finally: + end = time.time() + elapsed = end - start + ... # Submit `elapsed` as a metric to a monitoring backend +``` + +See also [`timing-asgi`](https://github.com/steinnes/timing-asgi) for a full example of this pattern. + +### Gotchas + +#### ASGI middleware should be stateless + +Because ASGI is designed to handle concurrent requests, any connection-specific state should be scoped to the `__call__` implementation. Not doing so would typically lead to conflicting variable reads/writes across requests, and most likely bugs. + +As an example, this would conditionally replace the response body, if an `X-Mock` header is present in the response... + +=== "✅ Do" + + ```python + from starlette.datastructures import Headers + + class MockResponseBodyMiddleware: + def __init__(self, app, content): + self.app = app + self.content = content + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + # A flag that we will turn `True` if the HTTP response + # has the 'X-Mock' header. + # ✅: Scoped to this function. + should_mock = False + + async def maybe_send_with_mock_content(message): + nonlocal should_mock + + if message["type"] == "http.response.start": + headers = Headers(raw=message["headers"]) + should_mock = headers.get("X-Mock") == "1" + await send(message) + + elif message["type"] == "http.response.body": + if should_mock: + message = {"type": "http.response.body", "body": self.content} + await send(message) + + await self.app(scope, receive, maybe_send_with_mock_content) + ``` + +=== "❌ Don't" + + ```python hl_lines="7-8" + from starlette.datastructures import Headers + + class MockResponseBodyMiddleware: + def __init__(self, app, content): + self.app = app + self.content = content + # ❌: This variable would be read and written across requests! + self.should_mock = False + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + async def maybe_send_with_mock_content(message): + if message["type"] == "http.response.start": + headers = Headers(raw=message["headers"]) + self.should_mock = headers.get("X-Mock") == "1" + await send(message) + + elif message["type"] == "http.response.body": + if self.should_mock: + message = {"type": "http.response.body", "body": self.content} + await send(message) + + await self.app(scope, receive, maybe_send_with_mock_content) + ``` + +See also [`GZipMiddleware`](https://github.com/encode/starlette/blob/9ef1b91c9c043197da6c3f38aa153fd874b95527/starlette/middleware/gzip.py) for a full example implementation that navigates this potential gotcha. + +### Further reading + +This documentation should be enough to have a good basis on how to create an ASGI middleware. + +Nonetheless, there are great articles about the subject: + +- [Introduction to ASGI: Emergence of an Async Python Web Ecosystem](https://florimond.dev/en/posts/2019/08/introduction-to-asgi-async-python-web/) +- [How to write ASGI middleware](https://pgjones.dev/blog/how-to-write-asgi-middleware-2021/) + ## Using middleware in other frameworks To wrap ASGI middleware around other ASGI applications, you should use the diff --git a/docs/release-notes.md b/docs/release-notes.md index f2cb34cce..c2d098eb8 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,3 +1,56 @@ +## 0.20.4 + +June 28, 2022 + +### Fixed +* Remove converter from path when generating OpenAPI schema [#1648](https://github.com/encode/starlette/pull/1648). + +## 0.20.3 + +June 10, 2022 + +### Fixed +* Revert "Allow `StaticFiles` to follow symlinks" [#1681](https://github.com/encode/starlette/pull/1681). + +## 0.20.2 + +June 7, 2022 + +### Fixed +* Fix regression on route paths with colons [#1675](https://github.com/encode/starlette/pull/1675). +* Allow `StaticFiles` to follow symlinks [#1337](https://github.com/encode/starlette/pull/1377). + +## 0.20.1 + +May 28, 2022 + +### Fixed +* Improve detection of async callables [#1444](https://github.com/encode/starlette/pull/1444). +* Send 400 (Bad Request) when `boundary` is missing [#1617](https://github.com/encode/starlette/pull/1617). +* Send 400 (Bad Request) when missing "name" field on `Content-Disposition` header [#1643](https://github.com/encode/starlette/pull/1643). +* Do not send empty data to `StreamingResponse` on `BaseHTTPMiddleware` [#1609](https://github.com/encode/starlette/pull/1609). +* Add `__bool__` dunder for `Secret` [#1625](https://github.com/encode/starlette/pull/1625). + +## 0.20.0 + +May 3, 2022 + +### Removed +* Drop Python 3.6 support [#1357](https://github.com/encode/starlette/pull/1357) and [#1616](https://github.com/encode/starlette/pull/1616). + + +## 0.19.1 + +April 22, 2022 + +### Fixed +* Fix inference of `Route.name` when created from methods [#1553](https://github.com/encode/starlette/pull/1553). +* Avoid `TypeError` on `websocket.disconnect` when code is `None` [#1574](https://github.com/encode/starlette/pull/1574). + +### Deprecated +* Deprecate `WS_1004_NO_STATUS_RCVD` and `WS_1005_ABNORMAL_CLOSURE` in favor of `WS_1005_NO_STATUS_RCVD` and `WS_1006_ABNORMAL_CLOSURE`, as the previous constants didn't match the [WebSockets specs](https://www.iana.org/assignments/websocket/websocket.xhtml) [#1580](https://github.com/encode/starlette/pull/1580). + + ## 0.19.0 March 9, 2022 diff --git a/docs/requests.md b/docs/requests.md index 872946638..747e496d1 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -142,6 +142,13 @@ filename = form["upload_file"].filename contents = await form["upload_file"].read() ``` +!!! info + As settled in [RFC-7578: 4.2](https://www.ietf.org/rfc/rfc7578.txt), form-data content part that contains file + assumed to have `name` and `filename` fields in `Content-Disposition` header: `Content-Disposition: form-data; + name="user"; filename="somefile"`. Though `filename` field is optional according to RFC-7578, it helps + Starlette to differentiate which data should be treated as file. If `filename` field was supplied, `UploadFile` + object will be created to access underlying file, otherwise form-data part will be parsed and available as a raw + string. #### Application diff --git a/docs/responses.md b/docs/responses.md index ce91f7ffa..a8c1b78e4 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -188,3 +188,7 @@ async def app(scope, receive, send): #### [EventSourceResponse](https://github.com/sysid/sse-starlette) A response class that implements [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html). It enables event streaming from the server to the client without the complexity of websockets. + +#### [baize.asgi.FileResponse](https://baize.aber.sh/asgi#fileresponse) + +As a smooth replacement for Starlette [`FileResponse`](https://www.starlette.io/responses/#fileresponse), it will automatically handle [Head method](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/HEAD) and [Range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests). diff --git a/docs/schemas.md b/docs/schemas.md index 275e7b296..fed596dbc 100644 --- a/docs/schemas.md +++ b/docs/schemas.md @@ -46,8 +46,8 @@ def openapi_schema(request): routes = [ - Route("/users", endpoint=list_users, methods=["GET"]) - Route("/users", endpoint=create_user, methods=["POST"]) + Route("/users", endpoint=list_users, methods=["GET"]), + Route("/users", endpoint=create_user, methods=["POST"]), Route("/schema", endpoint=openapi_schema, include_in_schema=False) ] diff --git a/docs/testclient.md b/docs/testclient.md index 16680b3dd..053b42005 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -26,11 +26,49 @@ function calls, not awaitables. You can use any of `httpx` standard API, such as authentication, session cookies handling, or file uploads. +For example, to set headers on the TestClient you can do: + +```python +client = TestClient(app) + +# Set headers on the client for future requests +client.headers = {"Authorization": "..."} +response = client.get("/") + +# Set headers for each request separately +response = client.get("/", headers={"Authorization": "..."}) +``` + +And for example to send files with the TestClient: + +```python +client = TestClient(app) + +# Send a single file +with open("example.txt", "rb") as f: + response = client.post("/form", files={"file": f}) + +# Send multiple files +with open("example.txt", "rb") as f1: + with open("example.png", "rb") as f2: + files = {"file1": f1, "file2": ("filename", f2, "image/png")} + response = client.post("/form", files=files) +``` + +For more information you can check the `requests` [documentation](https://requests.readthedocs.io/en/master/user/advanced/). + By default the `TestClient` will raise any exceptions that occur in the application. Occasionally you might want to test the content of 500 error responses, rather than allowing client to raise the server exception. In this case you should use `client = TestClient(app, raise_server_exceptions=False)`. +!!! note + + If you want the `TestClient` to run `lifespan` events (`on_startup`, `on_shutdown`, or `lifespan`), + you will need to use the `TestClient` as a context manager. Otherwise, the events + will not be triggered when the `TestClient` is instantiated. You can learn more about it + [here](/events/#running-event-handlers-in-tests). + ### Selecting the Async backend `TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary). @@ -97,6 +135,15 @@ May raise `starlette.websockets.WebSocketDisconnect` if the application does not `websocket_connect()` must be used as a context manager (in a `with` block). +!!! note + The `params` argument is not supported by `websocket_connect`. If you need to pass query arguments, hard code it + directly in the URL. + + ```python + with client.websocket_connect('/path?foo=bar') as websocket: + ... + ``` + #### Sending data * `.send_text(data)` - Send the given text to the application. @@ -114,3 +161,41 @@ May raise `starlette.websockets.WebSocketDisconnect`. #### Closing the connection * `.close(code=1000)` - Perform a client-side close of the websocket connection. + +### Asynchronous tests + +Sometimes you will want to do async things outside of your application. +For example, you might want to check the state of your database after calling your app using your existing async database client / infrastructure. + +For these situations, using `TestClient` is difficult because it creates it's own event loop and async resources (like a database connection) often cannot be shared across event loops. +The simplest way to work around this is to just make your entire test async and use an async client, like [httpx.AsyncClient]. + +Here is an example of such a test: + +```python +from httpx import AsyncClient +from starlette.applications import Starlette +from starlette.routing import Route +from starlette.requests import Request +from starlette.responses import PlainTextResponse + + +def hello(request: Request) -> PlainTextResponse: + return PlainTextResponse("Hello World!") + + +app = Starlette(routes=[Route("/", hello)]) + + +# if you're using pytest, you'll need to to add an async marker like: +# @pytest.mark.anyio # using https://github.com/agronholm/anyio +# or install and configure pytest-asyncio (https://github.com/pytest-dev/pytest-asyncio) +async def test_app() -> None: + # note: you _must_ set `base_url` for relative urls like "/" to work + async with AsyncClient(app=app, base_url="http://testserver") as client: + r = await client.get("/") + assert r.status_code == 200 + assert r.text == "Hello World!" +``` + +[httpx.AsyncClient]: https://www.python-httpx.org/advanced/#calling-into-python-web-apps diff --git a/docs/third-party-packages.md b/docs/third-party-packages.md index 66af00472..01da069e0 100644 --- a/docs/third-party-packages.md +++ b/docs/third-party-packages.md @@ -97,6 +97,15 @@ A plugin for providing an endpoint that exposes [Prometheus](https://prometheus. A simple tool for integrating Starlette and WTForms. It is modeled on the excellent Flask-WTF library. +### Starlette-Login + +GitHub | +Documentation + +User session management for Starlette. +It handles the common tasks of logging in, logging out, and remembering your users' sessions over extended periods of time. + + ### Starsessions GitHub @@ -121,6 +130,13 @@ FastAPI style routing for Starlette. Allows you to use decorators to generate routing tables. +### Starception + +GitHub + +Beautiful exception page for Starlette apps. + + ## Frameworks ### FastAPI diff --git a/mkdocs.yml b/mkdocs.yml index 379f4cbc5..6b1d6cc46 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,6 +44,7 @@ nav: - Configuration: 'config.md' - Test Client: 'testclient.md' - Third Party Packages: 'third-party-packages.md' + - Contributing: 'contributing.md' - Release Notes: 'release-notes.md' markdown_extensions: @@ -51,6 +52,8 @@ markdown_extensions: - admonition - pymdownx.highlight - pymdownx.superfences + - pymdownx.tabbed: + alternate_style: true extra_javascript: - 'js/chat.js' diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..11dcdccfb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,52 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "starlette" +dynamic = ["version"] +description = "The little ASGI library that shines." +readme = "README.md" +license = "BSD-3-Clause" +requires-python = ">=3.7" +authors = [ + { name = "Tom Christie", email = "tom@tomchristie.com" }, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Framework :: AnyIO", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Internet :: WWW/HTTP", +] +dependencies = [ + "anyio>=3.4.0,<5", + "typing_extensions>=3.10.0; python_version < '3.10'", +] + +[project.optional-dependencies] +full = [ + "itsdangerous", + "jinja2", + "python-multipart", + "pyyaml", + "httpx", +] + +[project.urls] +Homepage = "https://github.com/encode/starlette" + +[tool.hatch.version] +path = "starlette/__init__.py" + +[tool.hatch.build.targets.sdist] +include = [ + "/starlette", +] diff --git a/requirements.txt b/requirements.txt index 440c482b9..024d87e30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,23 +3,26 @@ # Testing autoflake==1.4 -black==22.3.0 -coverage==6.2 +black==22.6.0 +coverage==6.4.2 databases[sqlite]==0.5.5 flake8==3.9.2 isort==5.10.1 -mypy==0.942 -types-contextvars==2.4.2 -types-PyYAML==6.0.4 -types-dataclasses==0.6.2 -pytest==7.0.1 -trio==0.19.0 +mypy==0.971 +typing_extensions==4.3.0 +types-contextvars==2.4.7 +types-PyYAML==6.0.11 +types-dataclasses==0.6.6 +pytest==7.1.2 +trio==0.21.0 +# NOTE: Remove once greenlet releases 2.0.0. +greenlet==2.0.0a2; python_version >= "3.11" # Documentation -mkdocs==1.3.0 -mkdocs-material==8.2.8 +mkdocs==1.3.1 +mkdocs-material==8.3.9 mkautodoc==0.1.0 # Packaging -twine==3.8.0 -wheel==0.37.1 +build==0.8.0 +twine==4.0.1 diff --git a/scripts/build b/scripts/build index 1c47d2cc2..92378cb94 100755 --- a/scripts/build +++ b/scripts/build @@ -8,6 +8,6 @@ fi set -x -${PREFIX}python setup.py sdist bdist_wheel +${PREFIX}python -m build ${PREFIX}twine check dist/* ${PREFIX}mkdocs build diff --git a/scripts/check b/scripts/check index 23d50c7c3..5a38477cf 100755 --- a/scripts/check +++ b/scripts/check @@ -8,6 +8,7 @@ export SOURCE_FILES="starlette tests" set -x +./scripts/sync-version ${PREFIX}isort --check --diff --project=starlette $SOURCE_FILES ${PREFIX}black --check --diff $SOURCE_FILES ${PREFIX}flake8 $SOURCE_FILES diff --git a/scripts/sync-version b/scripts/sync-version new file mode 100755 index 000000000..67e9ef278 --- /dev/null +++ b/scripts/sync-version @@ -0,0 +1,9 @@ +#!/bin/sh -e + +SEMVER_REGEX="([0-9]+)\.([0-9]+)\.([0-9]+)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+[0-9A-Za-z-]+)?" +CHANGELOG_VERSION=$(grep -o -E $SEMVER_REGEX docs/release-notes.md | head -1) +VERSION=$(grep -o -E $SEMVER_REGEX starlette/__init__.py | head -1) +if [ "$CHANGELOG_VERSION" != "$VERSION" ]; then + echo "Version in changelog does not match version in starlette/__init__.py!" + exit 1 +fi diff --git a/setup.cfg b/setup.cfg index 4381655cb..23cf32cc0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,8 +5,12 @@ max-line-length = 88 [mypy] disallow_untyped_defs = True ignore_missing_imports = True +no_implicit_optional = True show_error_codes = True +[mypy-starlette.testclient] +no_implicit_optional = False + [mypy-tests.*] disallow_untyped_defs = False check_untyped_defs = True @@ -24,15 +28,13 @@ xfail_strict=True filterwarnings= # Turn warnings that aren't filtered into exceptions error - ignore: Using or importing the ABCs from 'collections' instead of from 'collections\.abc' is deprecated.*:DeprecationWarning - ignore: The 'context' alias has been deprecated. Please use 'context_value' instead\.:DeprecationWarning - ignore: The 'variables' alias has been deprecated. Please use 'variable_values' instead\.:DeprecationWarning - # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097) - ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10\.:DeprecationWarning:asyncio - ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning - ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning ignore: run_until_first_complete is deprecated and will be removed in a future version.:DeprecationWarning ignore: starlette\.middleware\.wsgi is deprecated and will be removed in a future release\.*:DeprecationWarning + ignore: Async generator 'starlette\.requests\.Request\.stream' was garbage collected before it had been exhausted.*:ResourceWarning + ignore: path is deprecated.*:DeprecationWarning:certifi + ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning + ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning + ignore: 'cgi' is deprecated and slated for removal in Python 3\.13:DeprecationWarning [coverage:run] source_pkgs = starlette, tests @@ -42,3 +44,4 @@ exclude_lines = pragma: no cover pragma: nocover if typing.TYPE_CHECKING: + @typing.overload diff --git a/setup.py b/setup.py deleted file mode 100644 index f293b62ef..000000000 --- a/setup.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import re - -from setuptools import setup, find_packages - - -def get_version(package): - """ - Return package version as listed in `__version__` in `init.py`. - """ - with open(os.path.join(package, "__init__.py")) as f: - return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) - - -def get_long_description(): - """ - Return the README. - """ - with open("README.md", encoding="utf8") as f: - return f.read() - - -setup( - name="starlette", - python_requires=">=3.6", - version=get_version("starlette"), - url="https://github.com/encode/starlette", - license="BSD", - description="The little ASGI library that shines.", - long_description=get_long_description(), - long_description_content_type="text/markdown", - author="Tom Christie", - author_email="tom@tomchristie.com", - packages=find_packages(exclude=["tests*"]), - package_data={"starlette": ["py.typed"]}, - include_package_data=True, - install_requires=[ - "anyio>=3.4.0,<5", - "typing_extensions>=3.10.0; python_version < '3.10'", - "contextlib2 >= 21.6.0; python_version < '3.7'", - ], - extras_require={ - "full": [ - "itsdangerous", - "jinja2", - "python-multipart", - "pyyaml", - "httpx>=0.22.0", - ] - }, - classifiers=[ - "Development Status :: 3 - Alpha", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Topic :: Internet :: WWW/HTTP", - "Framework :: AnyIO", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - ], - zip_safe=False, -) diff --git a/starlette/__init__.py b/starlette/__init__.py index 11ac8e1a9..8b8252f48 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.19.0" +__version__ = "0.20.4" diff --git a/starlette/_utils.py b/starlette/_utils.py new file mode 100644 index 000000000..0710aebdc --- /dev/null +++ b/starlette/_utils.py @@ -0,0 +1,12 @@ +import asyncio +import functools +import typing + + +def is_async_callable(obj: typing.Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return asyncio.iscoroutinefunction(obj) or ( + callable(obj) and asyncio.iscoroutinefunction(obj.__call__) + ) diff --git a/starlette/applications.py b/starlette/applications.py index d83f70d2f..c3daade5c 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,10 +1,10 @@ import typing from starlette.datastructures import State, URLPath -from starlette.exceptions import ExceptionMiddleware from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router @@ -41,17 +41,22 @@ class Starlette: def __init__( self, debug: bool = False, - routes: typing.Sequence[BaseRoute] = None, - middleware: typing.Sequence[Middleware] = None, - exception_handlers: typing.Mapping[ - typing.Any, - typing.Callable[ - [Request, Exception], typing.Union[Response, typing.Awaitable[Response]] - ], + routes: typing.Optional[typing.Sequence[BaseRoute]] = None, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, + exception_handlers: typing.Optional[ + typing.Mapping[ + typing.Any, + typing.Callable[ + [Request, Exception], + typing.Union[Response, typing.Awaitable[Response]], + ], + ] + ] = None, + on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, + on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, + lifespan: typing.Optional[ + typing.Callable[["Starlette"], typing.AsyncContextManager] ] = None, - on_startup: typing.Sequence[typing.Callable] = None, - on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. @@ -124,7 +129,7 @@ def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover return self.router.on_event(event_type) def mount( - self, path: str, app: ASGIApp, name: str = None + self, path: str, app: ASGIApp, name: typing.Optional[str] = None ) -> None: # pragma: nocover """ We no longer document this API, and its usage is discouraged. @@ -141,7 +146,7 @@ def mount( self.router.mount(path, app=app, name=name) def host( - self, host: str, app: ASGIApp, name: str = None + self, host: str, app: ASGIApp, name: typing.Optional[str] = None ) -> None: # pragma: no cover """ We no longer document this API, and its usage is discouraged. @@ -180,8 +185,8 @@ def add_route( self, path: str, route: typing.Callable, - methods: typing.List[str] = None, - name: str = None, + methods: typing.Optional[typing.List[str]] = None, + name: typing.Optional[str] = None, include_in_schema: bool = True, ) -> None: # pragma: no cover self.router.add_route( @@ -189,7 +194,7 @@ def add_route( ) def add_websocket_route( - self, path: str, route: typing.Callable, name: str = None + self, path: str, route: typing.Callable, name: typing.Optional[str] = None ) -> None: # pragma: no cover self.router.add_websocket_route(path, route, name=name) @@ -205,8 +210,8 @@ def decorator(func: typing.Callable) -> typing.Callable: def route( self, path: str, - methods: typing.List[str] = None, - name: str = None, + methods: typing.Optional[typing.List[str]] = None, + name: typing.Optional[str] = None, include_in_schema: bool = True, ) -> typing.Callable: # pragma: nocover """ @@ -234,7 +239,7 @@ def decorator(func: typing.Callable) -> typing.Callable: return decorator def websocket_route( - self, path: str, name: str = None + self, path: str, name: typing.Optional[str] = None ) -> typing.Callable: # pragma: nocover """ We no longer document this decorator style API, and its usage is discouraged. diff --git a/starlette/authentication.py b/starlette/authentication.py index 1a4cba377..32713eb17 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -1,14 +1,16 @@ -import asyncio import functools import inspect import typing from urllib.parse import urlencode +from starlette._utils import is_async_callable from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request from starlette.responses import RedirectResponse, Response from starlette.websockets import WebSocket +_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) + def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: for scope in scopes: @@ -20,8 +22,8 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo def requires( scopes: typing.Union[str, typing.Sequence[str]], status_code: int = 403, - redirect: str = None, -) -> typing.Callable: + redirect: typing.Optional[str] = None, +) -> typing.Callable[[_CallableType], _CallableType]: scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator(func: typing.Callable) -> typing.Callable: @@ -53,7 +55,7 @@ async def websocket_wrapper( return websocket_wrapper - elif asyncio.iscoroutinefunction(func): + elif is_async_callable(func): # Handle async request/response functions. @functools.wraps(func) async def async_wrapper( @@ -95,7 +97,7 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: return sync_wrapper - return decorator + return decorator # type: ignore[return-value] class AuthenticationError(Exception): @@ -110,7 +112,7 @@ async def authenticate( class AuthCredentials: - def __init__(self, scopes: typing.Sequence[str] = None): + def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None): self.scopes = [] if scopes is None else list(scopes) diff --git a/starlette/background.py b/starlette/background.py index 14a4e9e1a..4aaf7ae3c 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,4 +1,3 @@ -import asyncio import sys import typing @@ -7,6 +6,7 @@ else: # pragma: no cover from typing_extensions import ParamSpec +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool P = ParamSpec("P") @@ -19,7 +19,7 @@ def __init__( self.func = func self.args = args self.kwargs = kwargs - self.is_async = asyncio.iscoroutinefunction(func) + self.is_async = is_async_callable(func) async def __call__(self) -> None: if self.is_async: @@ -29,7 +29,7 @@ async def __call__(self) -> None: class BackgroundTasks(BackgroundTask): - def __init__(self, tasks: typing.Sequence[BackgroundTask] = None): + def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None): self.tasks = list(tasks) if tasks else [] def add_task( diff --git a/starlette/config.py b/starlette/config.py index bd809afb4..8c58b2738 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -52,7 +52,7 @@ def __len__(self) -> int: class Config: def __init__( self, - env_file: typing.Union[str, Path] = None, + env_file: typing.Optional[typing.Union[str, Path]] = None, environ: typing.Mapping[str, str] = environ, ) -> None: self.environ = environ @@ -61,15 +61,17 @@ def __init__( self.file_values = self._read_file(env_file) @typing.overload - def __call__( - self, key: str, cast: typing.Type[T], default: T = ... - ) -> T: # pragma: no cover + def __call__(self, key: str, *, default: None) -> typing.Optional[str]: + ... + + @typing.overload + def __call__(self, key: str, cast: typing.Type[T], default: T = ...) -> T: ... @typing.overload def __call__( self, key: str, cast: typing.Type[str] = ..., default: str = ... - ) -> str: # pragma: no cover + ) -> str: ... @typing.overload @@ -78,22 +80,28 @@ def __call__( key: str, cast: typing.Callable[[typing.Any], T] = ..., default: typing.Any = ..., - ) -> T: # pragma: no cover + ) -> T: ... @typing.overload def __call__( self, key: str, cast: typing.Type[str] = ..., default: T = ... - ) -> typing.Union[T, str]: # pragma: no cover + ) -> typing.Union[T, str]: ... def __call__( - self, key: str, cast: typing.Callable = None, default: typing.Any = undefined + self, + key: str, + cast: typing.Optional[typing.Callable] = None, + default: typing.Any = undefined, ) -> typing.Any: return self.get(key, cast, default) def get( - self, key: str, cast: typing.Callable = None, default: typing.Any = undefined + self, + key: str, + cast: typing.Optional[typing.Callable] = None, + default: typing.Any = undefined, ) -> typing.Any: if key in self.environ: value = self.environ[key] @@ -118,7 +126,7 @@ def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str return file_values def _perform_cast( - self, key: str, value: typing.Any, cast: typing.Callable = None + self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None ) -> typing.Any: if cast is None or value is None: return value diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 59863282a..42ec7a9ea 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -13,6 +13,13 @@ class Address(typing.NamedTuple): port: int +_KeyType = typing.TypeVar("_KeyType") +# Mapping keys are invariant but their values are covariant since +# you can only read them +# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` +_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) + + class URL: def __init__( self, @@ -206,6 +213,9 @@ def __repr__(self) -> str: def __str__(self) -> str: return self._value + def __bool__(self) -> bool: + return bool(self._value) + class CommaSeparatedStrings(Sequence): def __init__(self, value: typing.Union[str, typing.Sequence[str]]): @@ -235,32 +245,36 @@ def __str__(self) -> str: return ", ".join(repr(item) for item in self) -class ImmutableMultiDict(typing.Mapping): +class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): + _dict: typing.Dict[_KeyType, _CovariantValueType] + def __init__( self, *args: typing.Union[ - "ImmutableMultiDict", - typing.Mapping, - typing.List[typing.Tuple[typing.Any, typing.Any]], + "ImmutableMultiDict[_KeyType, _CovariantValueType]", + typing.Mapping[_KeyType, _CovariantValueType], + typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]], ], **kwargs: typing.Any, ) -> None: assert len(args) < 2, "Too many arguments." - value = args[0] if args else [] + value: typing.Any = args[0] if args else [] if kwargs: value = ( ImmutableMultiDict(value).multi_items() - + ImmutableMultiDict(kwargs).multi_items() + + ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator] ) if not value: _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = [] elif hasattr(value, "multi_items"): - value = typing.cast(ImmutableMultiDict, value) + value = typing.cast( + ImmutableMultiDict[_KeyType, _CovariantValueType], value + ) _items = list(value.multi_items()) elif hasattr(value, "items"): - value = typing.cast(typing.Mapping, value) + value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) _items = list(value.items()) else: value = typing.cast( @@ -271,33 +285,28 @@ def __init__( self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: typing.Any) -> typing.List[typing.Any]: + def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]: return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.KeysView: + def keys(self) -> typing.KeysView[_KeyType]: return self._dict.keys() - def values(self) -> typing.ValuesView: + def values(self) -> typing.ValuesView[_CovariantValueType]: return self._dict.values() - def items(self) -> typing.ItemsView: + def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: return self._dict.items() - def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]: return list(self._list) - def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: - if key in self._dict: - return self._dict[key] - return default - - def __getitem__(self, key: typing.Any) -> str: + def __getitem__(self, key: _KeyType) -> _CovariantValueType: return self._dict[key] def __contains__(self, key: typing.Any) -> bool: return key in self._dict - def __iter__(self) -> typing.Iterator[typing.Any]: + def __iter__(self) -> typing.Iterator[_KeyType]: return iter(self.keys()) def __len__(self) -> int: @@ -314,7 +323,7 @@ def __repr__(self) -> str: return f"{class_name}({items!r})" -class MultiDict(ImmutableMultiDict): +class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): def __setitem__(self, key: typing.Any, value: typing.Any) -> None: self.setlist(key, [value]) @@ -374,7 +383,7 @@ def update( self._dict.update(value) -class QueryParams(ImmutableMultiDict): +class QueryParams(ImmutableMultiDict[str, str]): """ An immutable multidict. """ @@ -468,7 +477,7 @@ async def close(self) -> None: await run_in_threadpool(self.file.close) -class FormData(ImmutableMultiDict): +class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): """ An immutable multidict, containing both file uploads and text input. """ @@ -531,12 +540,6 @@ def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore for key, value in self._list ] - def get(self, key: str, default: typing.Any = None) -> typing.Any: - try: - return self[key] - except KeyError: - return default - def getlist(self, key: str) -> typing.List[str]: get_header_key = key.lower().encode("latin-1") return [ diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 73367c257..156663e49 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -1,8 +1,8 @@ -import asyncio import json import typing from starlette import status +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request @@ -37,7 +37,7 @@ async def dispatch(self) -> None: handler: typing.Callable[[Request], typing.Any] = getattr( self, handler_name, self.method_not_allowed ) - is_async = asyncio.iscoroutinefunction(handler) + is_async = is_async_callable(handler) if is_async: response = await handler(request) else: @@ -80,7 +80,9 @@ async def dispatch(self) -> None: data = await self.decode(websocket, message) await self.on_receive(websocket, data) elif message["type"] == "websocket.disconnect": - close_code = int(message.get("code", status.WS_1000_NORMAL_CLOSURE)) + close_code = int( + message.get("code") or status.WS_1000_NORMAL_CLOSURE + ) break except Exception as exc: close_code = status.WS_1011_INTERNAL_ERROR diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 8f28b6e2d..2b5acddb5 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -1,16 +1,16 @@ -import asyncio import http import typing +import warnings -from starlette.concurrency import run_in_threadpool -from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response -from starlette.types import ASGIApp, Message, Receive, Scope, Send +__all__ = ("HTTPException",) class HTTPException(Exception): def __init__( - self, status_code: int, detail: str = None, headers: dict = None + self, + status_code: int, + detail: typing.Optional[str] = None, + headers: typing.Optional[dict] = None, ) -> None: if detail is None: detail = http.HTTPStatus(status_code).phrase @@ -23,86 +23,22 @@ def __repr__(self) -> str: return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" -class ExceptionMiddleware: - def __init__( - self, - app: ASGIApp, - handlers: typing.Mapping[ - typing.Any, typing.Callable[[Request, Exception], Response] - ] = None, - debug: bool = False, - ) -> None: - self.app = app - self.debug = debug # TODO: We ought to handle 404 cases if debug is set. - self._status_handlers: typing.Dict[int, typing.Callable] = {} - self._exception_handlers: typing.Dict[ - typing.Type[Exception], typing.Callable - ] = {HTTPException: self.http_exception} - if handlers is not None: - for key, value in handlers.items(): - self.add_exception_handler(key, value) - - def add_exception_handler( - self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], - handler: typing.Callable[[Request, Exception], Response], - ) -> None: - if isinstance(exc_class_or_status_code, int): - self._status_handlers[exc_class_or_status_code] = handler - else: - assert issubclass(exc_class_or_status_code, Exception) - self._exception_handlers[exc_class_or_status_code] = handler - - def _lookup_exception_handler( - self, exc: Exception - ) -> typing.Optional[typing.Callable]: - for cls in type(exc).__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - response_started = False - - async def sender(message: Message) -> None: - nonlocal response_started - - if message["type"] == "http.response.start": - response_started = True - await send(message) +__deprecated__ = "ExceptionMiddleware" - try: - await self.app(scope, receive, sender) - except Exception as exc: - handler = None - if isinstance(exc, HTTPException): - handler = self._status_handlers.get(exc.status_code) +def __getattr__(name: str) -> typing.Any: # pragma: no cover + if name == __deprecated__: + from starlette.middleware.exceptions import ExceptionMiddleware - if handler is None: - handler = self._lookup_exception_handler(exc) - - if handler is None: - raise exc - - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + warnings.warn( + f"{__deprecated__} is deprecated on `starlette.exceptions`. " + f"Import it from `starlette.middleware.exceptions` instead.", + category=DeprecationWarning, + stacklevel=3, + ) + return ExceptionMiddleware + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - request = Request(scope, receive=receive) - if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, sender) - def http_exception(self, request: Request, exc: HTTPException) -> Response: - if exc.status_code in {204, 304}: - return Response(status_code=exc.status_code, headers=exc.headers) - return PlainTextResponse( - exc.detail, status_code=exc.status_code, headers=exc.headers - ) +def __dir__() -> typing.List[str]: + return sorted(list(__all__) + [__deprecated__]) # pragma: no cover diff --git a/starlette/formparsers.py b/starlette/formparsers.py index fd1949229..53538c814 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -38,6 +38,11 @@ def _user_safe_decode(src: bytes, codec: str) -> str: return src.decode("latin-1") +class MultiPartException(Exception): + def __init__(self, message: str) -> None: + self.message = message + + class FormParser: def __init__( self, headers: Headers, stream: typing.AsyncGenerator[bytes, None] @@ -159,7 +164,10 @@ async def parse(self) -> FormData: charset = params.get(b"charset", "utf-8") if type(charset) == bytes: charset = charset.decode("latin-1") - boundary = params[b"boundary"] + try: + boundary = params[b"boundary"] + except KeyError: + raise MultiPartException("Missing boundary in multipart.") # Callbacks dictionary. callbacks = { @@ -212,7 +220,13 @@ async def parse(self) -> FormData: header_value = b"" elif message_type == MultiPartMessage.HEADERS_FINISHED: disposition, options = parse_options_header(content_disposition) - field_name = _user_safe_decode(options[b"name"], charset) + try: + field_name = _user_safe_decode(options[b"name"], charset) + except KeyError: + raise MultiPartException( + 'The Content-Disposition header field "name" must be ' + "provided." + ) if b"filename" in options: filename = _user_safe_decode(options[b"filename"], charset) file = UploadFile( diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 6e2d2dade..76e4a246d 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -16,8 +16,8 @@ def __init__( self, app: ASGIApp, backend: AuthenticationBackend, - on_error: typing.Callable[ - [HTTPConnection, AuthenticationError], Response + on_error: typing.Optional[ + typing.Callable[[HTTPConnection, AuthenticationError], Response] ] = None, ) -> None: self.app = app diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index bfb4a54a4..49a5e3e2d 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -13,7 +13,9 @@ class BaseHTTPMiddleware: - def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None) -> None: + def __init__( + self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None + ) -> None: self.app = app self.dispatch_func = self.dispatch if dispatch is None else dispatch @@ -50,7 +52,11 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: async with recv_stream: async for message in recv_stream: assert message["type"] == "http.response.body" - yield message.get("body", b"") + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break if app_exc is not None: raise app_exc diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index c850579c8..b36d155f5 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -18,7 +18,7 @@ def __init__( allow_methods: typing.Sequence[str] = ("GET",), allow_headers: typing.Sequence[str] = (), allow_credentials: bool = False, - allow_origin_regex: str = None, + allow_origin_regex: typing.Optional[str] = None, expose_headers: typing.Sequence[str] = (), max_age: int = 600, ) -> None: diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 474c9afc0..052b885f4 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -1,9 +1,9 @@ -import asyncio import html import inspect import traceback import typing +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response @@ -135,7 +135,10 @@ class ServerErrorMiddleware: """ def __init__( - self, app: ASGIApp, handler: typing.Callable = None, debug: bool = False + self, + app: ASGIApp, + handler: typing.Optional[typing.Callable] = None, + debug: bool = False, ) -> None: self.app = app self.handler = handler @@ -167,7 +170,7 @@ async def _send(message: Message) -> None: response = self.error_response(request, exc) else: # Use an installed 500 error handler. - if asyncio.iscoroutinefunction(self.handler): + if is_async_callable(self.handler): response = await self.handler(request, exc) else: response = await run_in_threadpool(self.handler, request, exc) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py new file mode 100644 index 000000000..42fd41ae2 --- /dev/null +++ b/starlette/middleware/exceptions.py @@ -0,0 +1,93 @@ +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class ExceptionMiddleware: + def __init__( + self, + app: ASGIApp, + handlers: typing.Optional[ + typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] + ] = None, + debug: bool = False, + ) -> None: + self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers: typing.Dict[int, typing.Callable] = {} + self._exception_handlers: typing.Dict[ + typing.Type[Exception], typing.Callable + ] = {HTTPException: self.http_exception} + if handlers is not None: + for key, value in handlers.items(): + self.add_exception_handler(key, value) + + def add_exception_handler( + self, + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable[[Request, Exception], Response], + ) -> None: + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler + + def _lookup_exception_handler( + self, exc: Exception + ) -> typing.Optional[typing.Callable]: + for cls in type(exc).__mro__: + if cls in self._exception_handlers: + return self._exception_handlers[cls] + return None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await self.app(scope, receive, sender) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = self._status_handlers.get(exc.status_code) + + if handler is None: + handler = self._lookup_exception_handler(exc) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + request = Request(scope, receive=receive) + if is_async_callable(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, sender) + + def http_exception(self, request: Request, exc: HTTPException) -> Response: + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + return PlainTextResponse( + exc.detail, status_code=exc.status_code, headers=exc.headers + ) diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 597de38a2..b1e32ec16 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -1,4 +1,5 @@ import json +import sys import typing from base64 import b64decode, b64encode @@ -9,6 +10,11 @@ from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Message, Receive, Scope, Send +if sys.version_info >= (3, 8): # pragma: no cover + from typing import Literal +else: # pragma: no cover + from typing_extensions import Literal + class SessionMiddleware: def __init__( @@ -18,7 +24,7 @@ def __init__( session_cookie: str = "session", max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds path: str = "/", - same_site: str = "lax", + same_site: Literal["lax", "strict", "none"] = "lax", https_only: bool = False, ) -> None: self.app = app diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 6bc4d2b5e..e84e6876a 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -11,7 +11,7 @@ class TrustedHostMiddleware: def __init__( self, app: ASGIApp, - allowed_hosts: typing.Sequence[str] = None, + allowed_hosts: typing.Optional[typing.Sequence[str]] = None, www_redirect: bool = True, ) -> None: if allowed_hosts is None: diff --git a/starlette/requests.py b/starlette/requests.py index e3c91e284..726abddcc 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,12 +1,12 @@ import json import typing -from collections.abc import Mapping from http import cookies as http_cookies import anyio from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State -from starlette.formparsers import FormParser, MultiPartParser +from starlette.exceptions import HTTPException +from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send try: @@ -59,13 +59,13 @@ class ClientDisconnect(Exception): pass -class HTTPConnection(Mapping): +class HTTPConnection(typing.Mapping[str, typing.Any]): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. """ - def __init__(self, scope: Scope, receive: Receive = None) -> None: + def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None: assert scope["type"] in ("http", "websocket") self.scope = scope @@ -142,7 +142,7 @@ def client(self) -> typing.Optional[Address]: return None @property - def session(self) -> dict: + def session(self) -> typing.Dict[str, typing.Any]: assert ( "session" in self.scope ), "SessionMiddleware must be installed to access request.session" @@ -230,7 +230,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: async def body(self) -> bytes: if not hasattr(self, "_body"): - chunks = [] + chunks: "typing.List[bytes]" = [] async for chunk in self.stream(): chunks.append(chunk) self._body = b"".join(chunks) @@ -248,10 +248,16 @@ async def form(self) -> FormData: parse_options_header is not None ), "The `python-multipart` library must be installed to use form parsing." content_type_header = self.headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) + content_type: bytes + content_type, _ = parse_options_header(content_type_header) if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.headers, self.stream()) - self._form = await multipart_parser.parse() + try: + multipart_parser = MultiPartParser(self.headers, self.stream()) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if "app" in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc elif content_type == b"application/x-www-form-urlencoded": form_parser = FormParser(self.headers, self.stream()) self._form = await form_parser.parse() @@ -279,7 +285,7 @@ async def is_disconnected(self) -> bool: async def send_push_promise(self, path: str) -> None: if "http.response.push" in self.scope.get("extensions", {}): - raw_headers = [] + raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = [] for name in SERVER_PUSH_HEADERS_TO_COPY: for value in self.headers.getlist(name): raw_headers.append( diff --git a/starlette/responses.py b/starlette/responses.py index b33bdd713..a4ca8caa4 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -17,6 +17,11 @@ from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send +if sys.version_info >= (3, 8): # pragma: no cover + from typing import Literal +else: # pragma: no cover + from typing_extensions import Literal + # Workaround for adding samesite support to pre 3.8 python http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore @@ -38,9 +43,9 @@ def __init__( self, content: typing.Any = None, status_code: int = 200, - headers: typing.Mapping[str, str] = None, - media_type: str = None, - background: BackgroundTask = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, ) -> None: self.status_code = status_code if media_type is not None: @@ -56,7 +61,9 @@ def render(self, content: typing.Any) -> bytes: return content return content.encode(self.charset) - def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: + def init_headers( + self, headers: typing.Optional[typing.Mapping[str, str]] = None + ) -> None: if headers is None: raw_headers: typing.List[typing.Tuple[bytes, bytes]] = [] populate_content_length = True @@ -97,15 +104,15 @@ def set_cookie( self, key: str, value: str = "", - max_age: int = None, - expires: int = None, + max_age: typing.Optional[int] = None, + expires: typing.Optional[int] = None, path: str = "/", - domain: str = None, + domain: typing.Optional[str] = None, secure: bool = False, httponly: bool = False, - samesite: str = "lax", + samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax", ) -> None: - cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie() + cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie() cookie[key] = value if max_age is not None: cookie[key]["max-age"] = max_age @@ -133,10 +140,10 @@ def delete_cookie( self, key: str, path: str = "/", - domain: str = None, + domain: typing.Optional[str] = None, secure: bool = False, httponly: bool = False, - samesite: str = "lax", + samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax", ) -> None: self.set_cookie( key, @@ -178,9 +185,9 @@ def __init__( self, content: typing.Any, status_code: int = 200, - headers: dict = None, - media_type: str = None, - background: BackgroundTask = None, + headers: typing.Optional[typing.Dict[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, ) -> None: super().__init__(content, status_code, headers, media_type, background) @@ -199,8 +206,8 @@ def __init__( self, url: typing.Union[str, URL], status_code: int = 307, - headers: typing.Mapping[str, str] = None, - background: BackgroundTask = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + background: typing.Optional[BackgroundTask] = None, ) -> None: super().__init__( content=b"", status_code=status_code, headers=headers, background=background @@ -208,14 +215,22 @@ def __init__( self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") +Content = typing.Union[str, bytes] +SyncContentStream = typing.Iterator[Content] +AsyncContentStream = typing.AsyncIterable[Content] +ContentStream = typing.Union[AsyncContentStream, SyncContentStream] + + class StreamingResponse(Response): + body_iterator: AsyncContentStream + def __init__( self, - content: typing.Any, + content: ContentStream, status_code: int = 200, - headers: typing.Mapping[str, str] = None, - media_type: str = None, - background: BackgroundTask = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, ) -> None: if isinstance(content, typing.AsyncIterable): self.body_iterator = content @@ -250,7 +265,7 @@ async def stream_response(self, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: + async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None: await func() task_group.cancel_scope.cancel() @@ -268,12 +283,12 @@ def __init__( self, path: typing.Union[str, "os.PathLike[str]"], status_code: int = 200, - headers: typing.Mapping[str, str] = None, - media_type: str = None, - background: BackgroundTask = None, - filename: str = None, - stat_result: os.stat_result = None, - method: str = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, + filename: typing.Optional[str] = None, + stat_result: typing.Optional[os.stat_result] = None, + method: typing.Optional[str] = None, content_disposition_type: str = "attachment", ) -> None: self.path = path diff --git a/starlette/routing.py b/starlette/routing.py index 0388304c9..1aa2cdb6d 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,15 +1,15 @@ -import asyncio import contextlib import functools import inspect import re -import sys import traceback import types import typing import warnings +from contextlib import asynccontextmanager from enum import Enum +from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath @@ -19,11 +19,6 @@ from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose -if sys.version_info >= (3, 7): - from contextlib import asynccontextmanager # pragma: no cover -else: - from contextlib2 import asynccontextmanager # pragma: no cover - class NoMatchFound(Exception): """ @@ -42,11 +37,16 @@ class Match(Enum): FULL = 2 -def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover """ Correctly determines if an object is a coroutine function, including those wrapped in functools.partial objects. """ + warnings.warn( + "iscoroutinefunction_or_partial is deprecated, " + "and will be removed in a future release.", + DeprecationWarning, + ) while isinstance(obj, functools.partial): obj = obj.func return inspect.iscoroutinefunction(obj) @@ -57,7 +57,7 @@ def request_response(func: typing.Callable) -> ASGIApp: Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = iscoroutinefunction_or_partial(func) + is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive=receive, send=send) @@ -84,7 +84,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def get_name(endpoint: typing.Callable) -> str: - if inspect.isfunction(endpoint) or inspect.isclass(endpoint): + if inspect.isroutine(endpoint) or inspect.isclass(endpoint): return endpoint.__name__ return endpoint.__class__.__name__ @@ -111,13 +111,16 @@ def compile_path( path: str, ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]: """ - Given a path string, like: "/{username:str}", return a three-tuple + Given a path string, like: "/{username:str}", + or a host string, like: "{subdomain}.mydomain.org", return a three-tuple of (regex, format, {param_name:convertor}). regex: "/(?P[^/]+)" format: "/{username}" convertors: {"username": StringConvertor()} """ + is_host = not path.startswith("/") + path_regex = "^" path_format = "" duplicated_params = set() @@ -150,7 +153,13 @@ def compile_path( ending = "s" if len(duplicated_params) > 1 else "" raise ValueError(f"Duplicated param name{ending} {names} at path {path}") - path_regex += re.escape(path[idx:].split(":")[0]) + "$" + if is_host: + # Align with `Host.matches()` behavior, which ignores port. + hostname = path[idx:].split(":")[0] + path_regex += re.escape(hostname) + "$" + else: + path_regex += re.escape(path[idx:]) + "$" + path_format += path[idx:] return re.compile(path_regex), path_format, param_convertors @@ -192,8 +201,8 @@ def __init__( path: str, endpoint: typing.Callable, *, - methods: typing.List[str] = None, - name: str = None, + methods: typing.Optional[typing.List[str]] = None, + name: typing.Optional[str] = None, include_in_schema: bool = True, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" @@ -276,7 +285,7 @@ def __eq__(self, other: typing.Any) -> bool: class WebSocketRoute(BaseRoute): def __init__( - self, path: str, endpoint: typing.Callable, *, name: str = None + self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -336,9 +345,9 @@ class Mount(BaseRoute): def __init__( self, path: str, - app: ASGIApp = None, - routes: typing.Sequence[BaseRoute] = None, - name: str = None, + app: typing.Optional[ASGIApp] = None, + routes: typing.Optional[typing.Sequence[BaseRoute]] = None, + name: typing.Optional[str] = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert ( @@ -426,7 +435,10 @@ def __eq__(self, other: typing.Any) -> bool: class Host(BaseRoute): - def __init__(self, host: str, app: ASGIApp, name: str = None) -> None: + def __init__( + self, host: str, app: ASGIApp, name: typing.Optional[str] = None + ) -> None: + assert not host.startswith("/"), "Host must not start with '/'" self.host = host self.app = app self.name = name @@ -537,12 +549,14 @@ def __call__(self: _T, app: object) -> _T: class Router: def __init__( self, - routes: typing.Sequence[BaseRoute] = None, + routes: typing.Optional[typing.Sequence[BaseRoute]] = None, redirect_slashes: bool = True, - default: ASGIApp = None, - on_startup: typing.Sequence[typing.Callable] = None, - on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None, + default: typing.Optional[ASGIApp] = None, + on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, + on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, + lifespan: typing.Optional[ + typing.Callable[[typing.Any], typing.AsyncContextManager] + ] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -604,7 +618,7 @@ async def startup(self) -> None: Run any `.on_startup` event handlers. """ for handler in self.on_startup: - if asyncio.iscoroutinefunction(handler): + if is_async_callable(handler): await handler() else: handler() @@ -614,7 +628,7 @@ async def shutdown(self) -> None: Run any `.on_shutdown` event handlers. """ for handler in self.on_shutdown: - if asyncio.iscoroutinefunction(handler): + if is_async_callable(handler): await handler() else: handler() @@ -700,7 +714,7 @@ def __eq__(self, other: typing.Any) -> bool: # The following usages are now discouraged in favour of configuration #  during Router.__init__(...) def mount( - self, path: str, app: ASGIApp, name: str = None + self, path: str, app: ASGIApp, name: typing.Optional[str] = None ) -> None: # pragma: nocover """ We no longer document this API, and its usage is discouraged. @@ -718,7 +732,7 @@ def mount( self.routes.append(route) def host( - self, host: str, app: ASGIApp, name: str = None + self, host: str, app: ASGIApp, name: typing.Optional[str] = None ) -> None: # pragma: no cover """ We no longer document this API, and its usage is discouraged. @@ -739,8 +753,8 @@ def add_route( self, path: str, endpoint: typing.Callable, - methods: typing.List[str] = None, - name: str = None, + methods: typing.Optional[typing.List[str]] = None, + name: typing.Optional[str] = None, include_in_schema: bool = True, ) -> None: # pragma: nocover route = Route( @@ -753,7 +767,7 @@ def add_route( self.routes.append(route) def add_websocket_route( - self, path: str, endpoint: typing.Callable, name: str = None + self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None ) -> None: # pragma: no cover route = WebSocketRoute(path, endpoint=endpoint, name=name) self.routes.append(route) @@ -761,8 +775,8 @@ def add_websocket_route( def route( self, path: str, - methods: typing.List[str] = None, - name: str = None, + methods: typing.Optional[typing.List[str]] = None, + name: typing.Optional[str] = None, include_in_schema: bool = True, ) -> typing.Callable: # pragma: nocover """ @@ -790,7 +804,7 @@ def decorator(func: typing.Callable) -> typing.Callable: return decorator def websocket_route( - self, path: str, name: str = None + self, path: str, name: typing.Optional[str] = None ) -> typing.Callable: # pragma: nocover """ We no longer document this decorator style API, and its usage is discouraged. diff --git a/starlette/schemas.py b/starlette/schemas.py index 6ca764fdc..55bf7b397 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -1,4 +1,5 @@ import inspect +import re import typing from starlette.requests import Request @@ -49,10 +50,11 @@ def get_endpoints( for route in routes: if isinstance(route, Mount): + path = self._remove_converter(route.path) routes = route.routes or [] sub_endpoints = [ EndpointInfo( - path="".join((route.path, sub_endpoint.path)), + path="".join((path, sub_endpoint.path)), http_method=sub_endpoint.http_method, func=sub_endpoint.func, ) @@ -64,23 +66,32 @@ def get_endpoints( continue elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): + path = self._remove_converter(route.path) for method in route.methods or ["GET"]: if method == "HEAD": continue endpoints_info.append( - EndpointInfo(route.path, method.lower(), route.endpoint) + EndpointInfo(path, method.lower(), route.endpoint) ) else: + path = self._remove_converter(route.path) for method in ["get", "post", "put", "patch", "delete", "options"]: if not hasattr(route.endpoint, method): continue func = getattr(route.endpoint, method) - endpoints_info.append( - EndpointInfo(route.path, method.lower(), func) - ) + endpoints_info.append(EndpointInfo(path, method.lower(), func)) return endpoints_info + def _remove_converter(self, path: str) -> str: + """ + Remove the converter from the path. + For example, a route like this: + Route("/users/{id:int}", endpoint=get_user, methods=["GET"]) + Should be represented as `/users/{id}` in the OpenAPI schema. + """ + return re.sub(r":\w+}", "}", path) + def parse_docstring(self, func_or_method: typing.Callable) -> dict: """ Given a function, parse the docstring as YAML and return a dictionary of info. diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index bd4d8bced..d09630f35 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -39,8 +39,10 @@ class StaticFiles: def __init__( self, *, - directory: PathLike = None, - packages: typing.List[typing.Union[str, typing.Tuple[str, str]]] = None, + directory: typing.Optional[PathLike] = None, + packages: typing.Optional[ + typing.List[typing.Union[str, typing.Tuple[str, str]]] + ] = None, html: bool = False, check_dir: bool = True, ) -> None: @@ -54,8 +56,10 @@ def __init__( def get_directories( self, - directory: PathLike = None, - packages: typing.List[typing.Union[str, typing.Tuple[str, str]]] = None, + directory: typing.Optional[PathLike] = None, + packages: typing.Optional[ + typing.List[typing.Union[str, typing.Tuple[str, str]]] + ] = None, ) -> typing.List[PathLike]: """ Given `directory` and `packages` arguments, return a list of all the diff --git a/starlette/status.py b/starlette/status.py index b122ae85c..1689328a4 100644 --- a/starlette/status.py +++ b/starlette/status.py @@ -5,6 +5,90 @@ And RFC 2324 - https://tools.ietf.org/html/rfc2324 """ +import warnings +from typing import List + +__all__ = ( + "HTTP_100_CONTINUE", + "HTTP_101_SWITCHING_PROTOCOLS", + "HTTP_102_PROCESSING", + "HTTP_103_EARLY_HINTS", + "HTTP_200_OK", + "HTTP_201_CREATED", + "HTTP_202_ACCEPTED", + "HTTP_203_NON_AUTHORITATIVE_INFORMATION", + "HTTP_204_NO_CONTENT", + "HTTP_205_RESET_CONTENT", + "HTTP_206_PARTIAL_CONTENT", + "HTTP_207_MULTI_STATUS", + "HTTP_208_ALREADY_REPORTED", + "HTTP_226_IM_USED", + "HTTP_300_MULTIPLE_CHOICES", + "HTTP_301_MOVED_PERMANENTLY", + "HTTP_302_FOUND", + "HTTP_303_SEE_OTHER", + "HTTP_304_NOT_MODIFIED", + "HTTP_305_USE_PROXY", + "HTTP_306_RESERVED", + "HTTP_307_TEMPORARY_REDIRECT", + "HTTP_308_PERMANENT_REDIRECT", + "HTTP_400_BAD_REQUEST", + "HTTP_401_UNAUTHORIZED", + "HTTP_402_PAYMENT_REQUIRED", + "HTTP_403_FORBIDDEN", + "HTTP_404_NOT_FOUND", + "HTTP_405_METHOD_NOT_ALLOWED", + "HTTP_406_NOT_ACCEPTABLE", + "HTTP_407_PROXY_AUTHENTICATION_REQUIRED", + "HTTP_408_REQUEST_TIMEOUT", + "HTTP_409_CONFLICT", + "HTTP_410_GONE", + "HTTP_411_LENGTH_REQUIRED", + "HTTP_412_PRECONDITION_FAILED", + "HTTP_413_REQUEST_ENTITY_TOO_LARGE", + "HTTP_414_REQUEST_URI_TOO_LONG", + "HTTP_415_UNSUPPORTED_MEDIA_TYPE", + "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE", + "HTTP_417_EXPECTATION_FAILED", + "HTTP_418_IM_A_TEAPOT", + "HTTP_421_MISDIRECTED_REQUEST", + "HTTP_422_UNPROCESSABLE_ENTITY", + "HTTP_423_LOCKED", + "HTTP_424_FAILED_DEPENDENCY", + "HTTP_425_TOO_EARLY", + "HTTP_426_UPGRADE_REQUIRED", + "HTTP_428_PRECONDITION_REQUIRED", + "HTTP_429_TOO_MANY_REQUESTS", + "HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE", + "HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS", + "HTTP_500_INTERNAL_SERVER_ERROR", + "HTTP_501_NOT_IMPLEMENTED", + "HTTP_502_BAD_GATEWAY", + "HTTP_503_SERVICE_UNAVAILABLE", + "HTTP_504_GATEWAY_TIMEOUT", + "HTTP_505_HTTP_VERSION_NOT_SUPPORTED", + "HTTP_506_VARIANT_ALSO_NEGOTIATES", + "HTTP_507_INSUFFICIENT_STORAGE", + "HTTP_508_LOOP_DETECTED", + "HTTP_510_NOT_EXTENDED", + "HTTP_511_NETWORK_AUTHENTICATION_REQUIRED", + "WS_1000_NORMAL_CLOSURE", + "WS_1001_GOING_AWAY", + "WS_1002_PROTOCOL_ERROR", + "WS_1003_UNSUPPORTED_DATA", + "WS_1005_NO_STATUS_RCVD", + "WS_1006_ABNORMAL_CLOSURE", + "WS_1007_INVALID_FRAME_PAYLOAD_DATA", + "WS_1008_POLICY_VIOLATION", + "WS_1009_MESSAGE_TOO_BIG", + "WS_1010_MANDATORY_EXT", + "WS_1011_INTERNAL_ERROR", + "WS_1012_SERVICE_RESTART", + "WS_1013_TRY_AGAIN_LATER", + "WS_1014_BAD_GATEWAY", + "WS_1015_TLS_HANDSHAKE", +) + HTTP_100_CONTINUE = 100 HTTP_101_SWITCHING_PROTOCOLS = 101 HTTP_102_PROCESSING = 102 @@ -79,8 +163,8 @@ WS_1001_GOING_AWAY = 1001 WS_1002_PROTOCOL_ERROR = 1002 WS_1003_UNSUPPORTED_DATA = 1003 -WS_1004_NO_STATUS_RCVD = 1004 -WS_1005_ABNORMAL_CLOSURE = 1005 +WS_1005_NO_STATUS_RCVD = 1005 +WS_1006_ABNORMAL_CLOSURE = 1006 WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007 WS_1008_POLICY_VIOLATION = 1008 WS_1009_MESSAGE_TOO_BIG = 1009 @@ -90,3 +174,26 @@ WS_1013_TRY_AGAIN_LATER = 1013 WS_1014_BAD_GATEWAY = 1014 WS_1015_TLS_HANDSHAKE = 1015 + + +__deprecated__ = {"WS_1004_NO_STATUS_RCVD": 1004, "WS_1005_ABNORMAL_CLOSURE": 1005} + + +def __getattr__(name: str) -> int: + deprecation_changes = { + "WS_1004_NO_STATUS_RCVD": "WS_1005_NO_STATUS_RCVD", + "WS_1005_ABNORMAL_CLOSURE": "WS_1006_ABNORMAL_CLOSURE", + } + deprecated = __deprecated__.get(name) + if deprecated: + warnings.warn( + f"'{name}' is deprecated. Use '{deprecation_changes[name]}' instead.", + category=DeprecationWarning, + stacklevel=3, + ) + return deprecated + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +def __dir__() -> List[str]: + return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover diff --git a/starlette/templating.py b/starlette/templating.py index 27939c95e..99035837f 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -28,9 +28,9 @@ def __init__( template: typing.Any, context: dict, status_code: int = 200, - headers: typing.Mapping[str, str] = None, - media_type: str = None, - background: BackgroundTask = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, ): self.template = template self.context = context @@ -88,9 +88,9 @@ def TemplateResponse( name: str, context: dict, status_code: int = 200, - headers: typing.Mapping[str, str] = None, - media_type: str = None, - background: BackgroundTask = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, ) -> _TemplateResponse: if "request" not in context: raise ValueError('context must include a "request" key') diff --git a/starlette/testclient.py b/starlette/testclient.py index 8b9dfb6c7..455440ce5 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import inspect import io @@ -16,6 +15,7 @@ import httpx from anyio.streams.stapled import StapledObjectStream +from starlette._utils import is_async_callable from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -36,10 +36,7 @@ def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool: if inspect.isclass(app): return hasattr(app, "__await__") - elif inspect.isfunction(app): - return asyncio.iscoroutinefunction(app) - call = getattr(app, "__call__", None) - return asyncio.iscoroutinefunction(call) + return is_async_callable(app) class _WrapASGI2: diff --git a/starlette/websockets.py b/starlette/websockets.py index 03ed19972..afcbde7fc 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -13,7 +13,7 @@ class WebSocketState(enum.Enum): class WebSocketDisconnect(Exception): - def __init__(self, code: int = 1000, reason: str = None) -> None: + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: self.code = code self.reason = reason or "" @@ -88,8 +88,8 @@ async def send(self, message: Message) -> None: async def accept( self, - subprotocol: str = None, - headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None, + subprotocol: typing.Optional[str] = None, + headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None, ) -> None: headers = headers or [] @@ -174,14 +174,16 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None: else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) - async def close(self, code: int = 1000, reason: str = None) -> None: + async def close( + self, code: int = 1000, reason: typing.Optional[str] = None + ) -> None: await self.send( {"type": "websocket.close", "code": code, "reason": reason or ""} ) class WebSocketClose: - def __init__(self, code: int = 1000, reason: str = None) -> None: + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: self.code = code self.reason = reason or "" diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 04da3a961..976d77b86 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,10 +1,13 @@ +import contextvars + import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse, StreamingResponse -from starlette.routing import Mount, Route, WebSocketRoute +from starlette.routing import Route, WebSocketRoute +from starlette.types import ASGIApp, Receive, Scope, Send class CustomMiddleware(BaseHTTPMiddleware): @@ -136,11 +139,6 @@ def homepage(request): assert response.headers["Custom-Header"] == "Example" -def test_middleware_repr(): - middleware = Middleware(CustomMiddleware) - assert repr(middleware) == "Middleware(CustomMiddleware)" - - def test_fully_evaluated_response(test_client_factory): # Test for https://github.com/encode/starlette/issues/1022 class CustomMiddleware(BaseHTTPMiddleware): @@ -155,11 +153,56 @@ async def dispatch(self, request, call_next): assert response.text == "Custom" -def test_exception_on_mounted_apps(test_client_factory): - sub_app = Starlette(routes=[Route("/", exc)]) - app = Starlette(routes=[Mount("/sub", app=sub_app)]) +ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + + +class CustomMiddlewareWithoutBaseHTTPMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ctxvar.set("set by middleware") + await self.app(scope, receive, send) + assert ctxvar.get() == "set by endpoint" + + +class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + ctxvar.set("set by middleware") + resp = await call_next(request) + assert ctxvar.get() == "set by endpoint" + return resp # pragma: no cover + + +@pytest.mark.parametrize( + "middleware_cls", + [ + CustomMiddlewareWithoutBaseHTTPMiddleware, + pytest.param( + CustomMiddlewareUsingBaseHTTPMiddleware, + marks=pytest.mark.xfail( + reason=( + "BaseHTTPMiddleware creates a TaskGroup which copies the context" + "and erases any changes to it made within the TaskGroup" + ), + raises=AssertionError, + ), + ), + ], +) +def test_contextvars(test_client_factory, middleware_cls: type): + # this has to be an async endpoint because Starlette calls run_in_threadpool + # on sync endpoints which has it's own set of peculiarities w.r.t propagating + # contextvars (it propagates them forwards but not backwards) + async def homepage(request): + assert ctxvar.get() == "set by middleware" + ctxvar.set("set by endpoint") + return PlainTextResponse("Homepage") + + app = Starlette( + middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] + ) client = test_client_factory(app) - with pytest.raises(Exception) as ctx: - client.get("/sub/") - assert str(ctx.value) == "Exc" + response = client.get("/") + assert response.status_code == 200, response.content diff --git a/tests/middleware/test_middleware.py b/tests/middleware/test_middleware.py new file mode 100644 index 000000000..f4d7a32f0 --- /dev/null +++ b/tests/middleware/test_middleware.py @@ -0,0 +1,10 @@ +from starlette.middleware import Middleware + + +class CustomMiddleware: + pass + + +def test_middleware_repr(): + middleware = Middleware(CustomMiddleware) + assert repr(middleware) == "Middleware(CustomMiddleware)" diff --git a/tests/test__utils.py b/tests/test__utils.py new file mode 100644 index 000000000..fac57a2e5 --- /dev/null +++ b/tests/test__utils.py @@ -0,0 +1,79 @@ +import functools + +from starlette._utils import is_async_callable + + +def test_async_func(): + async def async_func(): + ... # pragma: no cover + + def func(): + ... # pragma: no cover + + assert is_async_callable(async_func) + assert not is_async_callable(func) + + +def test_async_partial(): + async def async_func(a, b): + ... # pragma: no cover + + def func(a, b): + ... # pragma: no cover + + partial = functools.partial(async_func, 1) + assert is_async_callable(partial) + + partial = functools.partial(func, 1) + assert not is_async_callable(partial) + + +def test_async_method(): + class Async: + async def method(self): + ... # pragma: no cover + + class Sync: + def method(self): + ... # pragma: no cover + + assert is_async_callable(Async().method) + assert not is_async_callable(Sync().method) + + +def test_async_object_call(): + class Async: + async def __call__(self): + ... # pragma: no cover + + class Sync: + def __call__(self): + ... # pragma: no cover + + assert is_async_callable(Async()) + assert not is_async_callable(Sync()) + + +def test_async_partial_object_call(): + class Async: + async def __call__(self, a, b): + ... # pragma: no cover + + class Sync: + def __call__(self, a, b): + ... # pragma: no cover + + partial = functools.partial(Async(), 1) + assert is_async_callable(partial) + + partial = functools.partial(Sync(), 1) + assert not is_async_callable(partial) + + +def test_async_nested_partial(): + async def async_func(a, b): + ... # pragma: no cover + + partial = functools.partial(async_func, b=2) + nested_partial = functools.partial(partial, a=1) + assert is_async_callable(nested_partial) diff --git a/tests/test_applications.py b/tests/test_applications.py index 62ddd7602..0d0ede571 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,5 +1,5 @@ import os -import sys +from contextlib import asynccontextmanager import pytest @@ -12,11 +12,6 @@ from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -if sys.version_info >= (3, 7): - from contextlib import asynccontextmanager # pragma: no cover -else: - from contextlib2 import asynccontextmanager # pragma: no cover - async def error_500(request, exc): return JSONResponse({"detail": "Server Error"}, status_code=500) diff --git a/tests/test_background.py b/tests/test_background.py index e299ec362..fbe9dbf1b 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,5 +1,10 @@ +from typing import Callable + +import pytest + from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response +from starlette.testclient import TestClient def test_async_task(test_client_factory): @@ -40,7 +45,7 @@ async def app(scope, receive, send): assert TASK_COMPLETE -def test_multiple_tasks(test_client_factory): +def test_multiple_tasks(test_client_factory: Callable[..., TestClient]): TASK_COUNTER = 0 def increment(amount): @@ -61,3 +66,29 @@ async def app(scope, receive, send): response = client.get("/") assert response.text == "tasks initiated" assert TASK_COUNTER == 1 + 2 + 3 + + +def test_multi_tasks_failure_avoids_next_execution( + test_client_factory: Callable[..., TestClient] +) -> None: + TASK_COUNTER = 0 + + def increment(): + nonlocal TASK_COUNTER + TASK_COUNTER += 1 + if TASK_COUNTER == 1: + raise Exception("task failed") + + async def app(scope, receive, send): + tasks = BackgroundTasks() + tasks.add_task(increment) + tasks.add_task(increment) + response = Response( + "tasks initiated", media_type="text/plain", background=tasks + ) + await response(scope, receive, send) + + client = test_client_factory(app) + with pytest.raises(Exception): + client.get("/") + assert TASK_COUNTER == 1 diff --git a/tests/test_config.py b/tests/test_config.py index cfe908bc0..d33000389 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,12 +1,44 @@ import os from pathlib import Path +from typing import Any, Optional import pytest +from typing_extensions import assert_type from starlette.config import Config, Environ, EnvironError from starlette.datastructures import URL, Secret +def test_config_types() -> None: + """ + We use `assert_type` to test the types returned by Config via mypy. + """ + config = Config( + environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"} + ) + + assert_type(config("STR"), str) + assert_type(config("STR_DEFAULT", default=""), str) + assert_type(config("STR_CAST", cast=str), str) + assert_type(config("STR_NONE", default=None), Optional[str]) + assert_type(config("STR_CAST_NONE", cast=str, default=None), Optional[str]) + assert_type(config("STR_CAST_STR", cast=str, default=""), str) + + assert_type(config("BOOL", cast=bool), bool) + assert_type(config("BOOL_DEFAULT", cast=bool, default=False), bool) + assert_type(config("BOOL_NONE", cast=bool, default=None), Optional[bool]) + + def cast_to_int(v: Any) -> int: + return int(v) + + # our type annotations allow these `cast` and `default` configurations, but + # the code will error at runtime. + with pytest.raises(ValueError): + config("INT_CAST_DEFAULT_STR", cast=cast_to_int, default="true") + with pytest.raises(ValueError): + config("INT_DEFAULT_STR", cast=int, default="true") + + def test_config(tmpdir, monkeypatch): path = os.path.join(tmpdir, ".env") with open(path, "w") as file: @@ -27,7 +59,10 @@ def cast_to_int(v) -> int: DATABASE_URL = config("DATABASE_URL", cast=URL) REQUEST_TIMEOUT = config("REQUEST_TIMEOUT", cast=int, default=10) REQUEST_HOSTNAME = config("REQUEST_HOSTNAME") + MAIL_HOSTNAME = config("MAIL_HOSTNAME", default=None) SECRET_KEY = config("SECRET_KEY", cast=Secret) + UNSET_SECRET = config("UNSET_SECRET", cast=Secret, default=None) + EMPTY_SECRET = config("EMPTY_SECRET", cast=Secret, default="") assert config("BOOL_AS_INT", cast=bool) is False assert config("BOOL_AS_INT", cast=cast_to_int) == 0 assert config("DEFAULTED_BOOL", cast=cast_to_int, default=True) == 1 @@ -38,8 +73,12 @@ def cast_to_int(v) -> int: assert DATABASE_URL.username == "user" assert REQUEST_TIMEOUT == 10 assert REQUEST_HOSTNAME == "example.com" + assert MAIL_HOSTNAME is None assert repr(SECRET_KEY) == "Secret('**********')" assert str(SECRET_KEY) == "12345" + assert bool(SECRET_KEY) + assert not bool(EMPTY_SECRET) + assert not bool(UNSET_SECRET) with pytest.raises(KeyError): config.get("MISSING") diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 22e377c99..3ba8bbebc 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -220,7 +220,9 @@ def test_url_blank_params(): assert "abc" in q assert "def" in q assert "b" in q - assert len(q.get("abc")) == 0 + val = q.get("abc") + assert val is not None + assert len(val) == 0 assert len(q["a"]) == 3 assert list(q.keys()) == ["a", "abc", "def", "b"] @@ -342,6 +344,7 @@ def test_multidict(): q = MultiDict([("a", "123"), ("a", "456")]) q["a"] = "789" assert q["a"] == "789" + assert q.get("a") == "789" assert q.getlist("a") == ["789"] q = MultiDict([("a", "123"), ("a", "456")]) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 50f677467..9acd42154 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,6 +1,9 @@ +import warnings + import pytest -from starlette.exceptions import ExceptionMiddleware, HTTPException +from starlette.exceptions import HTTPException +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute @@ -130,3 +133,16 @@ class CustomHTTPException(HTTPException): assert repr(CustomHTTPException(500, detail="Something custom")) == ( "CustomHTTPException(status_code=500, detail='Something custom')" ) + + +def test_exception_middleware_deprecation() -> None: + # this test should be removed once the deprecation shim is removed + with pytest.warns(DeprecationWarning): + from starlette.exceptions import ExceptionMiddleware # noqa: F401 + + with warnings.catch_warnings(): + warnings.simplefilter("error") + import starlette.exceptions + + with pytest.warns(DeprecationWarning): + starlette.exceptions.ExceptionMiddleware diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 91668d164..4792424ab 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,11 +1,14 @@ import os import typing +from contextlib import nullcontext as does_not_raise import pytest -from starlette.formparsers import UploadFile, _user_safe_decode +from starlette.applications import Starlette +from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse +from starlette.routing import Mount class ForceMultipartDict(dict): @@ -20,7 +23,7 @@ def __bool__(self): async def app(scope, receive, send): request = Request(scope, receive) data = await request.form() - output = {} + output: typing.Dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -62,7 +65,7 @@ async def multi_items_app(scope, receive, send): async def app_with_headers(scope, receive, send): request = Request(scope, receive) data = await request.form() - output = {} + output: typing.Dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -390,10 +393,17 @@ def test_user_safe_decode_ignores_wrong_charset(): assert result == "abc" -def test_missing_boundary_parameter(test_client_factory): +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_missing_boundary_parameter(app, expectation, test_client_factory) -> None: client = test_client_factory(app) - with pytest.raises(KeyError, match="boundary"): - client.post( + with expectation: + res = client.post( "/", data=( # file @@ -403,3 +413,37 @@ def test_missing_boundary_parameter(test_client_factory): ), headers={"Content-Type": "multipart/form-data; charset=utf-8"}, ) + assert res.status_code == 400 + assert res.text == "Missing boundary in multipart." + + +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_missing_name_parameter_on_content_disposition( + app, expectation, test_client_factory +): + client = test_client_factory(app) + with expectation: + res = client.post( + "/", + data=( + # data + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; ="field0"\r\n\r\n' + b"value0\r\n" + ), + headers={ + "Content-Type": ( + "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" + ) + }, + ) + assert res.status_code == 400 + assert ( + res.text == 'The Content-Disposition header field "name" must be provided.' + ) diff --git a/tests/test_requests.py b/tests/test_requests.py index 5c3d28fb4..7422ad72a 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,3 +1,4 @@ +import sys from typing import Optional import anyio @@ -36,6 +37,10 @@ async def app(scope, receive, send): assert response.json() == {"params": {"a": "123", "b": "456"}} +@pytest.mark.skipif( + any(module in sys.modules for module in ("brotli", "brotlicffi")), + reason='urllib3 includes "br" to the "accept-encoding" headers.', +) def test_request_headers(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) diff --git a/tests/test_responses.py b/tests/test_responses.py index 48913a2c0..608842da2 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -391,3 +391,34 @@ def test_streaming_response_known_size(test_client_factory): client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "10" + + +@pytest.mark.anyio +async def test_streaming_response_stops_if_receiving_http_disconnect(): + streamed = 0 + + disconnected = anyio.Event() + + async def receive_disconnect(): + await disconnected.wait() + return {"type": "http.disconnect"} + + async def send(message): + nonlocal streamed + if message["type"] == "http.response.body": + streamed += len(message.get("body", b"")) + # Simulate disconnection after download has started + if streamed >= 16: + disconnected.set() + + async def stream_indefinitely(): + while True: + # Need a sleep for the event loop to switch to another task + await anyio.sleep(0) + yield b"chunk " + + response = StreamingResponse(content=stream_indefinitely()) + + with anyio.move_on_after(1) as cancel_scope: + await response({}, receive_disconnect, send) + assert not cancel_scope.cancel_called, "Content streaming should stop itself." diff --git a/tests/test_routing.py b/tests/test_routing.py index 7077c5616..e3b1e412a 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,4 +1,5 @@ import functools +import typing import uuid import pytest @@ -27,6 +28,11 @@ def user_me(request): return Response(content, media_type="text/plain") +def disable_user(request): + content = "User " + request.path_params["username"] + " disabled" + return Response(content, media_type="text/plain") + + def user_no_match(request): # pragma: no cover content = "User fixed no match" return Response(content, media_type="text/plain") @@ -108,6 +114,7 @@ async def websocket_params(session: WebSocket): Route("/", endpoint=users), Route("/me", endpoint=user_me), Route("/{username}", endpoint=user), + Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]), Route("/nomatch", endpoint=user_no_match), ], ), @@ -188,6 +195,11 @@ def test_router(client): assert response.url == "http://testserver/users/tomchristie" assert response.text == "User tomchristie" + response = client.put("/users/tomchristie:disable") + assert response.status_code == 200 + assert response.url == "http://testserver/users/tomchristie:disable" + assert response.text == "User tomchristie disabled" + response = client.get("/users/nomatch") assert response.status_code == 200 assert response.text == "User nomatch" @@ -428,7 +440,9 @@ def test_host_routing(test_client_factory): response = client.get("/") assert response.status_code == 200 - client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/") + client = test_client_factory( + mixed_hosts_app, base_url="https://port.example.org:3600/" + ) response = client.get("/users") assert response.status_code == 404 @@ -436,6 +450,13 @@ def test_host_routing(test_client_factory): response = client.get("/") assert response.status_code == 200 + # Port in requested Host is irrelevant. + + client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/") + + response = client.get("/") + assert response.status_code == 200 + client = test_client_factory( mixed_hosts_app, base_url="https://port.example.org:5600/" ) @@ -710,3 +731,51 @@ def test_duplicated_param_names(): match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}", ): Route("/{id}/{name}/{id}/{name}", user) + + +class Endpoint: + async def my_method(self, request): + ... # pragma: no cover + + @classmethod + async def my_classmethod(cls, request): + ... # pragma: no cover + + @staticmethod + async def my_staticmethod(request): + ... # pragma: no cover + + def __call__(self, request): + ... # pragma: no cover + + +@pytest.mark.parametrize( + "endpoint, expected_name", + [ + pytest.param(func_homepage, "func_homepage", id="function"), + pytest.param(Endpoint().my_method, "my_method", id="method"), + pytest.param(Endpoint.my_classmethod, "my_classmethod", id="classmethod"), + pytest.param( + Endpoint.my_staticmethod, + "my_staticmethod", + id="staticmethod", + ), + pytest.param(Endpoint(), "Endpoint", id="object"), + pytest.param(lambda request: ..., "", id="lambda"), + ], +) +def test_route_name(endpoint: typing.Callable, expected_name: str): + assert Route(path="/", endpoint=endpoint).name == expected_name + + +def test_exception_on_mounted_apps(test_client_factory): + def exc(request): + raise Exception("Exc") + + sub_app = Starlette(routes=[Route("/", exc)]) + app = Starlette(routes=[Mount("/sub", app=sub_app)]) + + client = test_client_factory(app) + with pytest.raises(Exception) as ctx: + client.get("/sub/") + assert str(ctx.value) == "Exc" diff --git a/tests/test_schemas.py b/tests/test_schemas.py index fa43785b9..26884b391 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -13,6 +13,17 @@ def ws(session): pass # pragma: no cover +def get_user(request): + """ + responses: + 200: + description: A user. + examples: + {"username": "tom"} + """ + pass # pragma: no cover + + def list_users(request): """ responses: @@ -103,6 +114,7 @@ def schema(request): app = Starlette( routes=[ WebSocketRoute("/ws", endpoint=ws), + Route("/users/{id:int}", endpoint=get_user, methods=["GET"]), Route("/users", endpoint=list_users, methods=["GET", "HEAD"]), Route("/users", endpoint=create_user, methods=["POST"]), Route("/orgs", endpoint=OrganisationsEndpoint), @@ -168,6 +180,16 @@ def test_schema_generation(): } }, }, + "/users/{id}": { + "get": { + "responses": { + 200: { + "description": "A user.", + "examples": {"username": "tom"}, + } + } + }, + }, }, } @@ -216,6 +238,13 @@ def test_schema_generation(): description: A user. examples: username: tom + /users/{id}: + get: + responses: + 200: + description: A user. + examples: + username: tom """ diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 8da232fda..142c2a00b 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -377,21 +377,28 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( def test_staticfiles_with_invalid_dir_permissions_returns_401( - tmpdir, test_client_factory + tmp_path, test_client_factory ): - path = os.path.join(tmpdir, "example.txt") - with open(path, "w") as file: - file.write("") - - os.chmod(tmpdir, stat.S_IRWXO) - - routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] - app = Starlette(routes=routes) - client = test_client_factory(app) - - response = client.get("/example.txt") - assert response.status_code == 401 - assert response.text == "Unauthorized" + (tmp_path / "example.txt").write_bytes(b"") + + original_mode = tmp_path.stat().st_mode + tmp_path.chmod(stat.S_IRWXO) + try: + routes = [ + Mount( + "/", + app=StaticFiles(directory=os.fsdecode(tmp_path)), + name="static", + ) + ] + app = Starlette(routes=routes) + client = test_client_factory(app) + + response = client.get("/example.txt") + assert response.status_code == 401 + assert response.text == "Unauthorized" + finally: + tmp_path.chmod(original_mode) def test_staticfiles_with_missing_dir_returns_404(tmpdir, test_client_factory): diff --git a/tests/test_status.py b/tests/test_status.py new file mode 100644 index 000000000..04719e87e --- /dev/null +++ b/tests/test_status.py @@ -0,0 +1,25 @@ +import importlib + +import pytest + + +@pytest.mark.parametrize( + "constant,msg", + ( + ( + "WS_1004_NO_STATUS_RCVD", + "'WS_1004_NO_STATUS_RCVD' is deprecated. " + "Use 'WS_1005_NO_STATUS_RCVD' instead.", + ), + ( + "WS_1005_ABNORMAL_CLOSURE", + "'WS_1005_ABNORMAL_CLOSURE' is deprecated. " + "Use 'WS_1006_ABNORMAL_CLOSURE' instead.", + ), + ), +) +def test_deprecated_types(constant: str, msg: str) -> None: + with pytest.warns(DeprecationWarning) as record: + getattr(importlib.import_module("starlette.status"), constant) + assert len(record) == 1 + assert msg in str(record.list[0]) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 22f0b3880..c9c7f33ca 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,6 +1,6 @@ -import asyncio import itertools -import sys +from asyncio import current_task as asyncio_current_task +from contextlib import asynccontextmanager import anyio import pytest @@ -13,13 +13,6 @@ from starlette.routing import Route from starlette.websockets import WebSocket, WebSocketDisconnect -if sys.version_info >= (3, 7): # pragma: no cover - from asyncio import current_task as asyncio_current_task - from contextlib import asynccontextmanager -else: # pragma: no cover - asyncio_current_task = asyncio.Task.current_task - from contextlib2 import asynccontextmanager - def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index f3970967e..c1ec1153e 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,3 +1,5 @@ +import sys + import anyio import pytest @@ -48,6 +50,10 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert data == {"params": {"a": "abc", "b": "456"}} +@pytest.mark.skipif( + any(module in sys.modules for module in ("brotli", "brotlicffi")), + reason='urllib3 includes "br" to the "accept-encoding" headers.', +) def test_websocket_headers(test_client_factory): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send)