From f0caa0f001ae8c50448b4328e806cc3f4ee320c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 May 2019 14:59:48 +0400 Subject: [PATCH 01/14] :sparkles: Add WebSocket exception handling --- starlette/exceptions.py | 48 ++++++++++++++++++++++++++++------ starlette/middleware/errors.py | 15 ++++++++--- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 0ef621508..7e3f272a8 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -2,10 +2,12 @@ import http import typing +from starlette import status from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocket, WebSocketClose class HTTPException(Exception): @@ -16,13 +18,31 @@ def __init__(self, status_code: int, detail: str = None) -> None: self.detail = detail +class WebSocketException(Exception): + def __init__(self, code: int = status.WS_1008_POLICY_VIOLATION) -> None: + """ + `code` defaults to 1008, from the WebSocket specification: + + > 1008 indicates that an endpoint is terminating the connection + > because it has received a message that violates its policy. This + > is a generic status code that can be returned when there is no + > other more suitable status code (e.g., 1003 or 1009) or if there + > is a need to hide specific details about the policy. + + Set `code` to any value allowed by + [the WebSocket specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). + """ + self.code = code + + class ExceptionMiddleware: def __init__(self, app: ASGIApp, debug: bool = False) -> None: self.app = app self.debug = debug # TODO: We ought to handle 404 cases if debug is set. self._status_handlers = {} # type: typing.Dict[int, typing.Callable] self._exception_handlers = { - HTTPException: self.http_exception + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, } # type: typing.Dict[typing.Type[Exception], typing.Callable] def add_exception_handler( @@ -45,7 +65,7 @@ def _lookup_exception_handler( return None async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": + if scope["type"] not in {"http", "websocket"}: await self.app(scope, receive, send) return @@ -76,14 +96,26 @@ async def sender(message: Message) -> None: msg = "Caught handled exception, but response already started." raise RuntimeError(msg) from exc - request = Request(scope, receive=receive) - if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, sender) + if scope["type"] == "http": + request = Request(scope, receive=receive) + if asyncio.iscoroutinefunction(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, sender) + elif scope["type"] == "websocket": + websocket = WebSocket(scope, receive=receive, send=send) + if asyncio.iscoroutinefunction(handler): + await handler(websocket, exc) + else: + await run_in_threadpool(handler, websocket, exc) def http_exception(self, request: Request, exc: HTTPException) -> Response: if exc.status_code in {204, 304}: return Response(b"", status_code=exc.status_code) return PlainTextResponse(exc.detail, status_code=exc.status_code) + + async def websocket_exception( + self, websocket: WebSocket, exc: WebSocketException + ) -> None: + await websocket.close(code=exc.code) diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 54f5fd2ae..643ebc104 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -2,10 +2,12 @@ import traceback import typing +from starlette import status from starlette.concurrency import run_in_threadpool -from starlette.requests import Request +from starlette.requests import Request, empty_receive from starlette.responses import HTMLResponse, PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocket, WebSocketState STYLES = """ .traceback-container { @@ -83,7 +85,7 @@ def __init__( self.debug = debug async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": + if scope["type"] not in {"http", "websocket"}: await self.app(scope, receive, send) return @@ -99,7 +101,7 @@ async def _send(message: Message) -> None: try: await self.app(scope, receive, _send) except Exception as exc: - if not response_started: + if not response_started and scope["type"] == "http": request = Request(scope) if self.debug: # In debug mode, return traceback responses. @@ -115,6 +117,13 @@ async def _send(message: Message) -> None: response = await run_in_threadpool(self.handler, request, exc) await response(scope, receive, send) + elif scope["type"] == "websocket": + websocket = WebSocket(scope, receive, send) + # https://tools.ietf.org/html/rfc6455#section-7.4.1 + # 1011 indicates that a server is terminating the connection because + # it encountered an unexpected condition that prevented it from + # fulfilling the request. + await websocket.close(code=status.WS_1011_INTERNAL_ERROR) # We always continue to raise the exception. # This allows servers to log the error, or allows test clients From 113d5c79dca0d45b1682de287e358205b3a1f9ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 May 2019 15:00:11 +0400 Subject: [PATCH 02/14] :white_check_mark: Test WebSocket exceptions --- tests/middleware/test_errors.py | 9 ++---- tests/test_applications.py | 53 ++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index ff74d87bd..768a4ee0b 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -3,6 +3,7 @@ from starlette.middleware.errors import ServerErrorMiddleware from starlette.responses import JSONResponse, Response from starlette.testclient import TestClient +from starlette.websockets import WebSocket, WebSocketDisconnect def test_handler(): @@ -55,16 +56,12 @@ async def app(scope, receive, send): client.get("/") -def test_debug_not_http(): - """ - DebugMiddleware should just pass through any non-http messages as-is. - """ - +def test_debug_websocket(): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnect): client = TestClient(app) client.websocket_connect("/") diff --git a/tests/test_applications.py b/tests/test_applications.py index bece5c7a6..9c2d4d25d 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,13 +1,18 @@ +import asyncio import os +import pytest + +from starlette import status from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint -from starlette.exceptions import HTTPException +from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect app = Starlette() @@ -86,6 +91,28 @@ async def websocket_endpoint(session): await session.close() +@app.websocket_route("/ws-raise-websocket") +async def websocket_raise_websocket_exception(websocket): + await websocket.accept() + raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) + + +class CustomWSException(Exception): + pass + + +@app.websocket_route("/ws-raise-custom") +async def websocket_raise_custom(websocket): + await websocket.accept() + raise CustomWSException() + + +@app.exception_handler(CustomWSException) +def custom_ws_exception_handler(websocket, exc): + loop = asyncio.new_event_loop() + loop.run_until_complete(websocket.close(code=status.WS_1013_TRY_AGAIN_LATER)) + + client = TestClient(app) @@ -164,6 +191,26 @@ def test_500(): assert response.json() == {"detail": "Server Error"} +def test_websocket_raise_websocket_exception(): + client = TestClient(app) + with client.websocket_connect("/ws-raise-websocket") as session: + response = session.receive() + assert response == { + "type": "websocket.close", + "code": status.WS_1003_UNSUPPORTED_DATA, + } + + +def test_websocket_raise_custom_exception(): + client = TestClient(app) + with client.websocket_connect("/ws-raise-custom") as session: + response = session.receive() + assert response == { + "type": "websocket.close", + "code": status.WS_1013_TRY_AGAIN_LATER, + } + + def test_middleware(): client = TestClient(app, base_url="http://incorrecthost") response = client.get("/func") @@ -191,6 +238,10 @@ def test_routes(): ), Route("/500", endpoint=runtime_error, methods=["GET"]), WebSocketRoute("/ws", endpoint=websocket_endpoint), + WebSocketRoute( + "/ws-raise-websocket", endpoint=websocket_raise_websocket_exception + ), + WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), ] From 1f0bc8640908320d71bd5680228a4ce59aa6a2ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 May 2019 15:11:30 +0400 Subject: [PATCH 03/14] :memo: Document WebSocketException --- docs/exceptions.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/exceptions.md b/docs/exceptions.md index 5a7afd35b..732cf81cd 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -42,6 +42,14 @@ async def http_exception(request, exc): return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) ``` +You might also want to override how `WebSocketException` is handled: + +```python +@app.exception_handler(WebSocketException) +async def websocket_exception(websocket, exc): + await websocket.close(code=1008) +``` + ## Errors and handled exceptions It is important to differentiate between handled exceptions and errors. @@ -74,3 +82,11 @@ returning plain-text HTTP responses for any `HTTPException`. You should only raise `HTTPException` inside routing or endpoints. Middleware classes should instead just return appropriate responses directly. + +## WebSocketException + +You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints. + +* `WebSocketException(code=1008)` + +You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). From 31adab155d91865e014298670adecdb1077c26e0 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 11 Aug 2021 19:29:04 +0200 Subject: [PATCH 04/14] merge master --- .codecov.yml | 11 - .github/FUNDING.yml | 1 + .github/ISSUE_TEMPLATE/2-bug-report.md | 53 +++ .github/ISSUE_TEMPLATE/3-feature-request.md | 33 ++ .github/ISSUE_TEMPLATE/config.yml | 7 + .github/workflows/publish.yml | 27 ++ .github/workflows/test-suite.yml | 33 ++ .gitignore | 10 +- .travis.yml | 19 - CONTRIBUTING.md | 79 ---- MANIFEST.in | 3 + README.md | 55 ++- docs/applications.md | 61 ++- docs/authentication.md | 32 +- docs/background.md | 21 +- docs/config.md | 34 +- docs/database.md | 63 +++- docs/endpoints.md | 32 +- docs/events.md | 60 +-- docs/exceptions.md | 20 +- docs/graphql.md | 38 +- docs/index.md | 52 ++- docs/js/chat.js | 3 + docs/js/sidecar-1.5.0.js | 6 + docs/middleware.md | 133 +++++-- docs/release-notes.md | 160 +++++++- docs/requests.md | 44 ++- docs/responses.md | 138 +++---- docs/routing.md | 216 +++++++++-- docs/schemas.md | 17 +- docs/server-push.md | 36 ++ docs/staticfiles.md | 28 +- docs/templates.md | 30 +- docs/testclient.md | 61 +-- docs/third-party-packages.md | 75 +++- docs/websockets.md | 21 +- mkdocs.yml | 56 +-- requirements.txt | 21 +- scripts/README.md | 5 +- scripts/build | 13 + scripts/check | 14 + scripts/coverage | 10 + scripts/docs | 10 + scripts/install | 23 +- scripts/lint | 9 +- scripts/publish | 32 +- scripts/test | 18 +- setup.cfg | 39 ++ setup.py | 29 +- starlette/__init__.py | 2 +- starlette/applications.py | 122 ++++-- starlette/authentication.py | 19 +- starlette/background.py | 4 +- starlette/concurrency.py | 30 +- starlette/config.py | 25 +- starlette/convertors.py | 14 +- starlette/datastructures.py | 110 ++++-- starlette/endpoints.py | 4 +- starlette/exceptions.py | 27 +- starlette/formparsers.py | 60 +-- starlette/graphql.py | 42 +-- starlette/middleware/__init__.py | 17 + starlette/middleware/authentication.py | 6 +- starlette/middleware/base.py | 67 ++-- starlette/middleware/cors.py | 46 ++- starlette/middleware/errors.py | 142 +++++-- starlette/middleware/gzip.py | 21 +- starlette/middleware/httpsredirect.py | 2 +- starlette/middleware/sessions.py | 16 +- starlette/middleware/trustedhost.py | 3 +- starlette/middleware/wsgi.py | 73 ++-- starlette/requests.py | 160 +++++--- starlette/responses.py | 113 ++++-- starlette/routing.py | 395 ++++++++++++++------ starlette/schemas.py | 2 + starlette/staticfiles.py | 57 +-- starlette/status.py | 19 +- starlette/templating.py | 12 +- starlette/testclient.py | 244 ++++++++---- starlette/websockets.py | 23 +- tests/.ignore_lifespan | 3 - tests/conftest.py | 25 ++ tests/middleware/__init__.py | 0 tests/middleware/test_base.py | 83 +++- tests/middleware/test_cors.py | 296 +++++++++++++-- tests/middleware/test_errors.py | 28 +- tests/middleware/test_gzip.py | 17 +- tests/middleware/test_https_redirect.py | 21 +- tests/middleware/test_lifespan.py | 108 ------ tests/middleware/test_session.py | 27 +- tests/middleware/test_trusted_host.py | 13 +- tests/middleware/test_wsgi.py | 34 +- tests/test_applications.py | 142 +++++-- tests/test_authentication.py | 143 ++++++- tests/test_background.py | 13 +- tests/test_concurrency.py | 22 ++ tests/test_config.py | 5 + tests/test_database.py | 18 +- tests/test_datastructures.py | 41 ++ tests/test_endpoints.py | 41 +- tests/test_exceptions.py | 41 +- tests/test_formparsers.py | 141 +++++-- tests/test_graphql.py | 70 ++-- tests/test_requests.py | 260 +++++++++++-- tests/test_responses.py | 151 +++++--- tests/test_routing.py | 333 ++++++++++++++++- tests/test_schemas.py | 5 +- tests/test_staticfiles.py | 148 ++++++-- tests/test_templates.py | 5 +- tests/test_testclient.py | 172 +++++++-- tests/test_websockets.py | 162 ++++++-- 111 files changed, 4645 insertions(+), 1791 deletions(-) delete mode 100644 .codecov.yml create mode 100644 .github/FUNDING.yml create mode 100644 .github/ISSUE_TEMPLATE/2-bug-report.md create mode 100644 .github/ISSUE_TEMPLATE/3-feature-request.md create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/workflows/publish.yml create mode 100644 .github/workflows/test-suite.yml delete mode 100644 .travis.yml delete mode 100644 CONTRIBUTING.md create mode 100644 MANIFEST.in create mode 100644 docs/js/chat.js create mode 100644 docs/js/sidecar-1.5.0.js create mode 100644 docs/server-push.md create mode 100755 scripts/build create mode 100755 scripts/check create mode 100755 scripts/coverage create mode 100755 scripts/docs create mode 100644 setup.cfg delete mode 100644 tests/.ignore_lifespan create mode 100644 tests/conftest.py create mode 100644 tests/middleware/__init__.py delete mode 100644 tests/middleware/test_lifespan.py create mode 100644 tests/test_concurrency.py diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index c2336342e..000000000 --- a/.codecov.yml +++ /dev/null @@ -1,11 +0,0 @@ -coverage: - precision: 2 - round: down - range: "80...100" - - status: - project: yes - patch: no - changes: no - -comment: off diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..2f87d94ca --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: encode diff --git a/.github/ISSUE_TEMPLATE/2-bug-report.md b/.github/ISSUE_TEMPLATE/2-bug-report.md new file mode 100644 index 000000000..7c11706b7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2-bug-report.md @@ -0,0 +1,53 @@ +--- +name: Bug report +about: Report a bug to help improve this project +--- + +### Checklist + + + +- [ ] The bug is reproducible against the latest release and/or `master`. +- [ ] There are no similar issues or pull requests to fix it yet. + +### Describe the bug + + + +### To reproduce + + + +### Expected behavior + + + +### Actual behavior + + + +### Debugging material + + + +### Environment + +- OS: +- Python version: +- Starlette version: + +### Additional context + + diff --git a/.github/ISSUE_TEMPLATE/3-feature-request.md b/.github/ISSUE_TEMPLATE/3-feature-request.md new file mode 100644 index 000000000..97336f516 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/3-feature-request.md @@ -0,0 +1,33 @@ +--- +name: Feature request +about: Suggest an idea for this project. +--- + +### Checklist + + + +- [ ] There are no similar issues or pull requests for this yet. +- [ ] I discussed this idea on the [community chat](https://gitter.im/encode/community) and feedback is positive. + +### Is your feature related to a problem? Please describe. + + + +## Describe the solution you would like. + + + +## Describe alternatives you considered + + + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..2ad6e8e27 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,7 @@ +# Ref: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository#configuring-the-template-chooser +blank_issues_enabled: true +contact_links: +- name: Question + url: https://gitter.im/encode/community + about: > + Ask a question diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 000000000..b290d6e1a --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,27 @@ +--- +name: Publish + +on: + push: + tags: + - '*' + +jobs: + publish: + name: "Publish release" + runs-on: "ubuntu-latest" + + steps: + - uses: "actions/checkout@v2" + - uses: "actions/setup-python@v2" + with: + python-version: 3.7 + - name: "Install dependencies" + run: "scripts/install" + - name: "Build package & docs" + run: "scripts/build" + - name: "Publish to PyPI & deploy docs" + run: "scripts/publish" + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml new file mode 100644 index 000000000..751c5193b --- /dev/null +++ b/.github/workflows/test-suite.yml @@ -0,0 +1,33 @@ +--- +name: Test Suite + +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +jobs: + tests: + name: "Python ${{ matrix.python-version }}" + runs-on: "ubuntu-latest" + + strategy: + matrix: + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10.0-beta.3"] + + steps: + - uses: "actions/checkout@v2" + - uses: "actions/setup-python@v2" + with: + python-version: "${{ matrix.python-version }}" + - name: "Install dependencies" + run: "scripts/install" + - name: "Run linting checks" + run: "scripts/check" + - name: "Build package & docs" + run: "scripts/build" + - name: "Run tests" + run: "scripts/test" + - name: "Enforce coverage" + run: "scripts/coverage" diff --git a/.gitignore b/.gitignore index 7b5d4318c..bff8fa258 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,11 @@ test.db .coverage .pytest_cache/ .mypy_cache/ -starlette.egg-info/ -venv/ +__pycache__/ +htmlcov/ +site/ +*.egg-info/ +venv*/ +.python-version +build/ +dist/ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 4f4ef5dc5..000000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -dist: xenial -language: python - -cache: pip - -python: - - "3.6" - - "3.7" - - "3.8-dev" - -install: - - pip install -U -r requirements.txt - -script: - - scripts/test - -after_script: - - pip install codecov - - codecov diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 6acbc2c87..000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,79 +0,0 @@ -# Contributing to Starlette - -The Starlette team happily welcomes contributions. This document will help you get ready to contribute to Starlette! - -To submit new code to the project you'll need to: - -* Fork the repo. -* Clone your fork on your local computer: `git clone https://github.com//starlette.git`. -* Install Starlette locally and run the tests: `./scripts/install`, `./scripts/test`. -* Create a branch for your work, e.g. `git checkout -b fix-some-bug`. -* Remember to include tests and documentation updates if applicable. -* Once ready, push to your remote: `git push origin fix-some-bug`. -* [Open a Pull Request][pull-request]. - -## Install - -**Note**: These scripts are currently suited to **Linux** and **macOS**, but we would happily take pull requests to help us make them more cross-compatible. - -Use the `install` script to install project dependencies in a virtual environment. - -```bash -./scripts/install -``` - -To use a specific Python executable, use the `-p` option, e.g.: - -```bash -./scripts/install -p python3.7 -``` - -## Running the tests - -The tests are written using [pytest] and located in the `tests/` directory. - -**Note**: tests should be run before making any changes to the code in order to make sure that everything is running as expected. - -We provide a stand-alone **test script** to run tests in a reliable manner. Run it with: - -```bash -./scripts/test -``` - -By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASES` environment variable. This should be a comma separated string of database URLs. - -```bash -# Any of the following are valid for running the database tests... -export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette" -export STARLETTE_TEST_DATABASES="mysql://localhost/starlette_test" -export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette_test" -``` - -## Linting - -We use [Black][black] as a code formatter. To run it along with a few other linting tools, we provide a stand-alone linting script: - -```bash -./scripts/lint -``` - -If linting has anything to say about the code, it will format it in-place. - -To keep the code style consistent, you should apply linting before committing. - -## Documentation - -The documentation is built with [MkDocs], a Markdown-based documentation site generator. - -To run the docs site in hot-reload mode (useful when editing the docs), run `$ mkdocs serve` in the project root directory. - -For your information, the docs site configuration is located in the `mkdocs.yml` file. - -Please refer to the [MkDocs docs][MkDocs] for more usage information, including how to add new pages. - -[issues]: https://github.com/encode/starlette/issues/new -[pull-request]: https://github.com/encode/starlette/compare -[pytest]: https://docs.pytest.org -[pytest-cov]: https://github.com/pytest-dev/pytest-cov -[black]: https://www.google.com/search?client=safari&rls=en&q=github+black&ie=UTF-8&oe=UTF-8 -[MkDocs]: https://www.mkdocs.org diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..9cccc91b7 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include LICENSE.md +global-exclude __pycache__ +global-exclude *.py[co] diff --git a/README.md b/README.md index 4de0e8b02..8eedea952 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,8 @@ ✨ The little ASGI framework that shines. ✨

- - Build Status - - - Coverage + + Build Status Package version @@ -25,7 +22,7 @@ # Starlette Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit, -which is ideal for building high performance asyncio services. +which is ideal for building high performance async services. It is production-ready, and gives you the following: @@ -39,7 +36,8 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. +* Compatible with `asyncio` and `trio` backends. ## Requirements @@ -59,36 +57,42 @@ $ pip3 install uvicorn ## Example +**example.py**: + ```python from starlette.applications import Starlette from starlette.responses import JSONResponse -import uvicorn - -app = Starlette(debug=True) +from starlette.routing import Route -@app.route('/') async def homepage(request): return JSONResponse({'hello': 'world'}) -if __name__ == '__main__': - uvicorn.run(app, host='0.0.0.0', port=8000) +routes = [ + Route("/", endpoint=homepage) +] + +app = Starlette(debug=True, routes=routes) +``` + +Then run the application using Uvicorn: + +```shell +$ uvicorn example:app ``` For a more complete example, see [encode/starlette-example](https://github.com/encode/starlette-example). ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. * [`graphene`][graphene] - Required for `GraphQLApp` support. -* [`ujson`][ujson] - Required if you want to use `UJSONResponse`. You can install all of these with `pip3 install starlette[full]`. @@ -101,20 +105,16 @@ an ASGI toolkit. You can use any of its components independently. from starlette.responses import PlainTextResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = PlainTextResponse('Hello, world!') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = PlainTextResponse('Hello, world!') + await response(scope, receive, send) ``` -Run the `App` application in `example.py`: +Run the `app` application in `example.py`: ```shell -$ uvicorn example:App +$ uvicorn example:app INFO: Started server process [11509] INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) ``` @@ -137,7 +137,6 @@ as [one of the fastest Python frameworks available](https://www.techempower.com/ For high throughput loads you should: -* Make sure to install `ujson` and use `UJSONResponse`. * Run using gunicorn using the `uvicorn` worker class. * Use one or two workers per-CPU core. (You might need to experiment with this.) * Disable access logging. @@ -168,11 +167,9 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ [sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation -[ujson]: https://github.com/esnme/ultrajson diff --git a/docs/applications.md b/docs/applications.md index 0092a3b4c..6fb74f19f 100644 --- a/docs/applications.md +++ b/docs/applications.md @@ -5,74 +5,57 @@ its other functionality. ```python from starlette.applications import Starlette from starlette.responses import PlainTextResponse +from starlette.routing import Route, Mount, WebSocketRoute from starlette.staticfiles import StaticFiles -app = Starlette() -app.debug = True -app.mount('/static', StaticFiles(directory="static")) - - -@app.route('/') def homepage(request): return PlainTextResponse('Hello, world!') -@app.route('/user/me') def user_me(request): username = "John Doe" return PlainTextResponse('Hello, %s!' % username) -@app.route('/user/{username}') def user(request): username = request.path_params['username'] return PlainTextResponse('Hello, %s!' % username) - -@app.websocket_route('/ws') async def websocket_endpoint(websocket): await websocket.accept() await websocket.send_text('Hello, websocket!') await websocket.close() - -@app.on_event('startup') def startup(): print('Ready to go') -``` - -### Instantiating the application -* `Starlette(debug=False)` - Create a new Starlette application. -### Adding routes to the application +routes = [ + Route('/', homepage), + Route('/user/me', user_me), + Route('/user/{username}', user), + WebSocketRoute('/ws', websocket_endpoint), + Mount('/static', StaticFiles(directory="static")), +] -You can use any of the following to add handled routes to the application: - -* `app.add_route(path, func, methods=["GET"])` - Add an HTTP route. The function may be either a coroutine or a regular function, with a signature like `func(request, **kwargs) -> response`. -* `app.add_websocket_route(path, func)` - Add a websocket session route. The function must be a coroutine, with a signature like `func(session, **kwargs)`. -* `@app.route(path)` - Add an HTTP route, decorator style. -* `@app.websocket_route(path)` - Add a WebSocket route, decorator style. - -### Adding event handlers to the application - -There are two ways to add event handlers: +app = Starlette(debug=True, routes=routes, on_startup=[startup]) +``` -* `@app.on_event(event_type)` - Add an event, decorator style -* `app.add_event_handler(event_type, func)` - Add an event through a function call. +### Instantiating the application -`event_type` must be specified as either `'startup'` or `'shutdown'`. +::: starlette.applications.Starlette + :docstring: -### Submounting other applications +### Storing state on the app instance -Submounting applications is a powerful way to include reusable ASGI applications. +You can store arbitrary extra state on the application instance, using the +generic `app.state` attribute. -* `app.mount(prefix, app)` - Include an ASGI app, mounted under the given path prefix +For example: -### Customizing exception handling +```python +app.state.ADMIN_EMAIL = 'admin@example.org' +``` -You can use either of the following to catch and handle particular types of -exceptions that occur within the application: +### Accessing the app instance -* `app.add_exception_handler(exc_class_or_status_code, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`. -* `@app.exception_handler(exc_class_or_status_code)` - Add an error handler, decorator style. -* `app.debug` - Enable or disable error tracebacks in the browser. +Where a `request` is available (i.e. endpoints and middleware), the app is available on `request.app`. diff --git a/docs/authentication.md b/docs/authentication.md index e0e21853c..d4af5b216 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -5,12 +5,15 @@ interfaces will be available in your endpoints. ```python +from starlette.applications import Starlette from starlette.authentication import ( AuthenticationBackend, AuthenticationError, SimpleUser, UnauthenticatedUser, AuthCredentials ) +from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.responses import PlainTextResponse +from starlette.routing import Route import base64 import binascii @@ -30,21 +33,24 @@ class BasicAuthBackend(AuthenticationBackend): raise AuthenticationError('Invalid basic auth credentials') username, _, password = decoded.partition(":") - # TODO: You'd want to verify the username and password here, - # possibly by installing `DatabaseMiddleware` - # and retrieving user information from `request.database`. + # TODO: You'd want to verify the username and password here. return AuthCredentials(["authenticated"]), SimpleUser(username) -app = Starlette() -app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) - - -@app.route('/') async def homepage(request): if request.user.is_authenticated: - return PlainTextResponse('hello, ' + request.user.display_name) - return PlainTextResponse('hello, you') + return PlainTextResponse('Hello, ' + request.user.display_name) + return PlainTextResponse('Hello, you') + +routes = [ + Route("/", endpoint=homepage) +] + +middleware = [ + Middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) +] + +app = Starlette(routes=routes, middleware=middleware) ``` ## Users @@ -81,7 +87,6 @@ incoming request includes the required authentication scopes. from starlette.authentication import requires -@app.route('/dashboard') @requires('authenticated') async def dashboard(request): ... @@ -93,7 +98,6 @@ You can include either one or multiple required scopes: from starlette.authentication import requires -@app.route('/dashboard') @requires(['authenticated', 'admin']) async def dashboard(request): ... @@ -107,7 +111,6 @@ about the URL layout from unauthenticated users. from starlette.authentication import requires -@app.route('/dashboard') @requires(['authenticated', 'admin'], status_code=404) async def dashboard(request): ... @@ -120,12 +123,10 @@ page. from starlette.authentication import requires -@app.route('/homepage') async def homepage(request): ... -@app.route('/dashboard') @requires('authenticated', redirect='homepage') async def dashboard(request): ... @@ -135,7 +136,6 @@ For class-based endpoints, you should wrap the decorator around a method on the class. ```python -@app.route("/dashboard") class Dashboard(HTTPEndpoint): @requires("authenticated") async def get(self, request): diff --git a/docs/background.md b/docs/background.md index d27fa65fe..e10832a92 100644 --- a/docs/background.md +++ b/docs/background.md @@ -13,11 +13,12 @@ Signature: `BackgroundTask(func, *args, **kwargs)` ```python from starlette.applications import Starlette from starlette.responses import JSONResponse +from starlette.routing import Route from starlette.background import BackgroundTask -app = Starlette() -@app.route('/user/signup', methods=['POST']) +... + async def signup(request): data = await request.json() username = data['username'] @@ -28,6 +29,14 @@ async def signup(request): async def send_welcome_email(to_address): ... + + +routes = [ + ... + Route('/user/signup', endpoint=signup, methods=['POST']) +] + +app = Starlette(routes=routes) ``` ### BackgroundTasks @@ -41,9 +50,6 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.background import BackgroundTasks -app = Starlette() - -@app.route('/user/signup', methods=['POST']) async def signup(request): data = await request.json() username = data['username'] @@ -60,4 +66,9 @@ async def send_welcome_email(to_address): async def send_admin_notification(username): ... +routes = [ + Route('/user/signup', endpoint=signup, methods=['POST']) +] + +app = Starlette(routes=routes) ``` diff --git a/docs/config.md b/docs/config.md index 74259c5b0..7a93b22e9 100644 --- a/docs/config.md +++ b/docs/config.md @@ -21,8 +21,7 @@ DATABASE_URL = config('DATABASE_URL', cast=databases.DatabaseURL) SECRET_KEY = config('SECRET_KEY', cast=Secret) ALLOWED_HOSTS = config('ALLOWED_HOSTS', cast=CommaSeparatedStrings) -app = Starlette() -app.debug = DEBUG +app = Starlette(debug=DEBUG) ... ``` @@ -86,7 +85,7 @@ type is useful. CommaSeparatedStrings(['127.0.0.1', 'localhost']) >>> print(list(settings.ALLOWED_HOSTS)) ['127.0.0.1', 'localhost'] ->>> print(len(settings.ALLOWED_HOSTS[0])) +>>> print(len(settings.ALLOWED_HOSTS)) 2 >>> print(settings.ALLOWED_HOSTS[0]) '127.0.0.1' @@ -160,28 +159,27 @@ organisations = sqlalchemy.Table( ```python from starlette.applications import Starlette -from starlette.middleware.database import DatabaseMiddleware +from starlette.middleware import Middleware from starlette.middleware.session import SessionMiddleware +from starlette.routing import Route from myproject import settings -app = Starlette() +async def homepage(request): + ... -app.debug = settings.DEBUG +routes = [ + Route("/", endpoint=homepage) +] -app.add_middleware( - SessionMiddleware, - secret_key=settings.SECRET_KEY, -) -app.add_middleware( - DatabaseMiddleware, - database_url=settings.DATABASE_URL, - rollback_on_shutdown=settings.TESTING -) +middleware = [ + Middleware( + SessionMiddleware, + secret_key=settings.SECRET_KEY, + ) +] -@app.route('/', methods=['GET']) -async def homepage(request): - ... +app = Starlette(debug=settings.DEBUG, routes=routes, middleware=middleware) ``` Now let's deal with our test configuration. diff --git a/docs/database.md b/docs/database.md index 39e605b75..ca1b85d6a 100644 --- a/docs/database.md +++ b/docs/database.md @@ -1,6 +1,6 @@ Starlette is not strictly tied to any particular database implementation. -You can use it with an asynchronous ORM, such as [GINO](https://python-gino.readthedocs.io/en/latest/), +You can use it with an asynchronous ORM, such as [GINO](https://python-gino.org/), or use regular non-async endpoints, and integrate with [SQLAlchemy](https://www.sqlalchemy.org/). In this documentation we'll demonstrate how to integrate against [the `databases` package](https://github.com/encode/databases), @@ -27,6 +27,7 @@ import sqlalchemy from starlette.applications import Starlette from starlette.config import Config from starlette.responses import JSONResponse +from starlette.routing import Route # Configuration from environment variables or '.env' file. @@ -45,22 +46,10 @@ notes = sqlalchemy.Table( sqlalchemy.Column("completed", sqlalchemy.Boolean), ) -# Main application code. database = databases.Database(DATABASE_URL) -app = Starlette() - - -@app.on_event("startup") -async def startup(): - await database.connect() - -@app.on_event("shutdown") -async def shutdown(): - await database.disconnect() - -@app.route("/notes", methods=["GET"]) +# Main application code. async def list_notes(request): query = notes.select() results = await database.fetch_all(query) @@ -73,8 +62,6 @@ async def list_notes(request): ] return JSONResponse(content) - -@app.route("/notes", methods=["POST"]) async def add_note(request): data = await request.json() query = notes.insert().values( @@ -86,8 +73,22 @@ async def add_note(request): "text": data["text"], "completed": data["completed"] }) + +routes = [ + Route("/notes", endpoint=list_notes, methods=["GET"]), + Route("/notes", endpoint=add_note, methods=["POST"]), +] + +app = Starlette( + routes=routes, + on_startup=[database.connect], + on_shutdown=[database.disconnect] +) ``` +Finally, you will need to create the database tables. It is recommended to use +Alembic, which we briefly go over in [Migrations](#migrations) + ## Queries Queries may be made with as [SQLAlchemy Core queries][sqlalchemy-core]. @@ -205,7 +206,7 @@ def create_test_database(): We use the `sqlalchemy_utils` package here for a few helpers in consistently creating and dropping the database. """ - url = str(app.DATABASE_URL) + url = str(app.TEST_DATABASE_URL) engine = create_engine(url) assert not database_exists(url), 'Test database already exists. Aborting tests.' create_database(url) # Create the test database. @@ -264,6 +265,34 @@ target_metadata = app.metadata ... ``` +Then, using our notes example above, create an initial revision: + +```shell +alembic revision -m "Create notes table" +``` + +And populate the new file (within `migrations/versions`) with the necessary directives: + +```python + +def upgrade(): + op.create_table( + 'notes', + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("text", sqlalchemy.String), + sqlalchemy.Column("completed", sqlalchemy.Boolean), + ) + +def downgrade(): + op.drop_table('notes') +``` + +And run your first migration. Our notes app can now run! + +```shell +alembic upgrade head +``` + **Running migrations during testing** It is good practice to ensure that your test suite runs the database migrations diff --git a/docs/endpoints.md b/docs/endpoints.md index fe05434c3..1362f5e80 100644 --- a/docs/endpoints.md +++ b/docs/endpoints.md @@ -17,30 +17,32 @@ class App(HTTPEndpoint): ``` If you're using a Starlette application instance to handle routing, you can -dispatch to an `HTTPEndpoint` class by using the `@app.route()` decorator, or the -`app.add_route()` function. Make sure to dispatch to the class itself, rather -than to an instance of the class: +dispatch to an `HTTPEndpoint` class. Make sure to dispatch to the class itself, +rather than to an instance of the class: ```python from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.endpoints import HTTPEndpoint +from starlette.routing import Route -app = Starlette() - - -@app.route("/") class Homepage(HTTPEndpoint): async def get(self, request): return PlainTextResponse(f"Hello, world!") -@app.route("/{username}") class User(HTTPEndpoint): async def get(self, request): username = request.path_params['username'] return PlainTextResponse(f"Hello, {username}") + +routes = [ + Route("/", Homepage), + Route("/{username}", User) +] + +app = Starlette(routes=routes) ``` HTTP endpoint classes will respond with "405 Method not allowed" responses for any @@ -90,8 +92,8 @@ import uvicorn from starlette.applications import Starlette from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint from starlette.responses import HTMLResponse +from starlette.routing import Route, WebSocketRoute -app = Starlette() html = """ @@ -127,22 +129,20 @@ html = """ """ - -@app.route("/") class Homepage(HTTPEndpoint): async def get(self, request): return HTMLResponse(html) - -@app.websocket_route("/ws") class Echo(WebSocketEndpoint): - encoding = "text" async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") +routes = [ + Route("/", Homepage), + WebSocketRoute("/ws", Echo) +] -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) +app = Starlette(routes=routes) ``` diff --git a/docs/events.md b/docs/events.md index 8dab26021..c7ed49e9d 100644 --- a/docs/events.md +++ b/docs/events.md @@ -5,49 +5,62 @@ is shutting down. ## Registering events -These event handlers can either be `async` coroutines, or regular syncronous +These event handlers can either be `async` coroutines, or regular synchronous functions. -The event handlers can be registered with a decorator syntax, like so: +The event handlers should be included on the application like so: ```python from starlette.applications import Starlette -app = Starlette() +async def some_startup_task(): + pass -@app.on_event('startup') -async def open_database_connection_pool(): - ... +async def some_shutdown_task(): + pass -@app.on_event('shutdown') -async def close_database_connection_pool(): +routes = [ ... +] + +app = Starlette( + routes=routes, + on_startup=[some_startup_task], + on_shutdown=[some_shutdown_task] +) ``` -Or as a regular function call: + +Starlette will not start serving any incoming requests until all of the +registered startup handlers have completed. + +The shutdown handlers will run once all connections have been closed, and +any in-process background tasks have completed. + +A single lifespan asynccontextmanager handler can be used instead of +separate startup and shutdown handlers: ```python +import contextlib +import anyio from starlette.applications import Starlette -app = Starlette() +@contextlib.asynccontextmanager +async def lifespan(app): + async with some_async_resource(): + yield -async def open_database_connection_pool(): - ... -async def close_database_connection_pool(): +routes = [ ... +] -app.add_event_handler('startup', open_database_connection_pool) -app.add_event_handler('shutdown', close_database_connection_pool) - +app = Starlette(routes=routes, lifespan=lifespan) ``` -Starlette will not start serving any incoming requests until all of the -registered startup handlers have completed. - -The shutdown handlers will run once all connections have been closed, and -any in-process background tasks have completed. +Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html) +for managing asynchronious tasks. ## Running event handlers in tests @@ -59,15 +72,14 @@ startup and shutdown events are called. ```python from example import app -from starlette.lifespan import LifespanContext from starlette.testclient import TestClient def test_homepage(): with TestClient(app) as client: - # Application 'startup' handlers are called on entering the block. + # Application 'on_startup' handlers are called on entering the block. response = client.get("/") assert response.status_code == 200 - # Application 'shutdown' handlers are called on exiting the block. + # Application 'on_shutdown' handlers are called on exiting the block. ``` diff --git a/docs/exceptions.md b/docs/exceptions.md index 732cf81cd..10451c86b 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -11,23 +11,26 @@ HTML_404_PAGE = ... HTML_500_PAGE = ... -app = Starlette() - - -@app.exception_handler(404) async def not_found(request, exc): return HTMLResponse(content=HTML_404_PAGE, status_code=exc.status_code) -@app.exception_handler(500) async def server_error(request, exc): return HTMLResponse(content=HTML_500_PAGE, status_code=exc.status_code) + + +exception_handlers = { + 404: not_found, + 500: server_error +} + +app = Starlette(routes=routes, exception_handlers=exception_handlers) ``` If `debug` is enabled and an error occurs, then instead of using the installed 500 handler, Starlette will respond with a traceback response. ```python -app = Starlette(debug=True) +app = Starlette(debug=True, routes=routes, exception_handlers=exception_handlers) ``` As well as registering handlers for specific status codes, you can also @@ -37,9 +40,12 @@ In particular you might want to override how the built-in `HTTPException` class is handled. For example, to use JSON style responses: ```python -@app.exception_handler(HTTPException) async def http_exception(request, exc): return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) + +exception_handlers = { + HTTPException: http_exception +} ``` You might also want to override how `WebSocketException` is handled: diff --git a/docs/graphql.md b/docs/graphql.md index 2fb00a55e..281bdea85 100644 --- a/docs/graphql.md +++ b/docs/graphql.md @@ -1,10 +1,25 @@ +!!! Warning + + GraphQL support in Starlette is **deprecated** as of version 0.15 and will + be removed in a future release. It is also incompatible with Python 3.10+. + Please consider using a third-party library to provide GraphQL support. This + is usually done by mounting a GraphQL ASGI application. + See [#619](https://github.com/encode/starlette/issues/619). + Some example libraries are: + + * [Ariadne](https://ariadnegraphql.org/docs/asgi) + * [`tartiflette-asgi`](https://tartiflette.github.io/tartiflette-asgi/) + * [Strawberry](https://strawberry.rocks/docs/integrations/asgi) + * [`starlette-graphene3`](https://github.com/ciscorn/starlette-graphene3) + Starlette includes optional support for GraphQL, using the `graphene` library. Here's an example of integrating the support into your application. ```python from starlette.applications import Starlette +from starlette.routing import Route from starlette.graphql import GraphQLApp import graphene @@ -15,9 +30,11 @@ class Query(graphene.ObjectType): def resolve_hello(self, info, name): return "Hello " + name +routes = [ + Route('/', GraphQLApp(schema=graphene.Schema(query=Query))) +] -app = Starlette() -app.add_route('/', GraphQLApp(schema=graphene.Schema(query=Query))) +app = Starlette(routes=routes) ``` If you load up the page in a browser, you'll be served the GraphiQL tool, @@ -67,15 +84,16 @@ async def log_user_agent(user_agent): If you're working with a standard ORM, then just use regular function calls for your "resolve" methods, and Starlette will manage running the GraphQL query within a -seperate thread. +separate thread. -If you want to use an asyncronous ORM, then use "async resolve" methods, and +If you want to use an asynchronous ORM, then use "async resolve" methods, and make sure to setup Graphene's AsyncioExecutor using the `executor` argument. ```python from graphql.execution.executors.asyncio import AsyncioExecutor from starlette.applications import Starlette from starlette.graphql import GraphQLApp +from starlette.routing import Route import graphene @@ -86,9 +104,13 @@ class Query(graphene.ObjectType): # We can make asynchronous network calls here. return "Hello " + name +routes = [ + # We're using `executor_class=AsyncioExecutor` here. + Route('/', GraphQLApp( + schema=graphene.Schema(query=Query), + executor_class=AsyncioExecutor + )) +] -app = Starlette() - -# We're using `executor_class=AsyncioExecutor` here. -app.add_route('/', GraphQLApp(schema=graphene.Schema(query=Query), executor_class=AsyncioExecutor)) +app = Starlette(routes=routes) ``` diff --git a/docs/index.md b/docs/index.md index d5f40c236..b9692a1fb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,11 +5,8 @@ ✨ The little ASGI framework that shines. ✨

- - Build Status - - - Coverage + + Build Status Package version @@ -28,7 +25,6 @@ It is production-ready, and gives you the following: * Seriously impressive performance. * WebSocket support. -* GraphQL support. * In-process background tasks. * Startup and shutdown events. * Test client built on `requests`. @@ -36,7 +32,7 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. ## Requirements @@ -56,35 +52,40 @@ $ pip3 install uvicorn ## Example +**example.py**: + ```python from starlette.applications import Starlette from starlette.responses import JSONResponse -import uvicorn +from starlette.routing import Route -app = Starlette(debug=True) -@app.route('/') async def homepage(request): return JSONResponse({'hello': 'world'}) -if __name__ == '__main__': - uvicorn.run(app, host='0.0.0.0', port=8000) + +app = Starlette(debug=True, routes=[ + Route('/', homepage), +]) +``` + +Then run the application... + +```shell +$ uvicorn example:app ``` For a more complete example, [see here](https://github.com/encode/starlette-example). ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following dependencies are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. -* [`graphene`][graphene] - Required for `GraphQLApp` support. -* [`ujson`][ujson] - Required if you want to use `UJSONResponse`. You can install all of these with `pip3 install starlette[full]`. @@ -97,20 +98,16 @@ an ASGI toolkit. You can use any of its components independently. from starlette.responses import PlainTextResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = PlainTextResponse('Hello, world!') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = PlainTextResponse('Hello, world!') + await response(scope, receive, send) ``` -Run the `App` application in `example.py`: +Run the `app` application in `example.py`: ```shell -$ uvicorn example:App +$ uvicorn example:app INFO: Started server process [11509] INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) ``` @@ -133,7 +130,6 @@ as [one of the fastest Python frameworks available](https://www.techempower.com/ For high throughput loads you should: -* Make sure to install `ujson` and use `UJSONResponse`. * Run using Gunicorn using the `uvicorn` worker class. * Use one or two workers per-CPU core. (You might need to experiment with this.) * Disable access logging. @@ -164,11 +160,9 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ [sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation -[ujson]: https://github.com/esnme/ultrajson diff --git a/docs/js/chat.js b/docs/js/chat.js new file mode 100644 index 000000000..c82454918 --- /dev/null +++ b/docs/js/chat.js @@ -0,0 +1,3 @@ +((window.gitter = {}).chat = {}).options = { + room: 'encode/community' +}; diff --git a/docs/js/sidecar-1.5.0.js b/docs/js/sidecar-1.5.0.js new file mode 100644 index 000000000..44899c4b3 --- /dev/null +++ b/docs/js/sidecar-1.5.0.js @@ -0,0 +1,6 @@ +/*! + * Gitter Sidecar v1.5.0 + * https://sidecar.gitter.im/ + */ +var sidecar=function(t){function e(r){if(i[r])return i[r].exports;var n=i[r]={exports:{},id:r,loaded:!1};return t[r].call(n.exports,n,n.exports,e),n.loaded=!0,n.exports}var i={};return e.m=t,e.c=i,e.p="",e(0)}([function(t,e,i){"use strict";function r(t){return t&&t.__esModule?t:{default:t}}Object.defineProperty(e,"__esModule",{value:!0});var n=Object.assign||function(t){for(var e=1;e1?i-1:0),n=1;n0&&void 0!==arguments[0]?arguments[0]:{};o(this,t),this[C]=new L.default,this[S]=[],this[I]=v,this[A]=u({},this[I],e),this[O]()}return a(t,[{key:O,value:function(){var t=this,e=this[A];e.useStyles&&this[C].add(h()),e.targetElement=(0,x.default)(e.targetElement||function(){var e=t[C].createElement("aside");return e.classList.add("gitter-chat-embed"),e.classList.add("is-collapsed"),m.appendChild(e),e}()),e.targetElement.forEach(function(e){var i=t[C].createElement("div");i.classList.add("gitter-chat-embed-loading-wrapper"),i.innerHTML='\n
\n ',e.insertBefore(i,e.firstChild)}),p(this),e.preload&&this.toggleChat(!1),e.showChatByDefault?this.toggleChat(!0):(void 0===e.activationElement||e.activationElement===!0?e.activationElement=(0,x.default)(function(){var i=t[C].createElement("a");return i.href=""+e.host+e.room,i.innerHTML="Open Chat",i.classList.add("gitter-open-chat-button"),m.appendChild(i),i}()):e.activationElement&&(e.activationElement=(0,x.default)(e.activationElement)),e.activationElement&&(z(e.activationElement,function(e){t.toggleChat(!0),e.preventDefault()}),e.targetElement.forEach(function(t){x.on(t,"gitter-chat-toggle",function(t){var i=t.detail.state;e.activationElement.forEach(function(t){x.toggleClass(t,"is-collapsed",i)})})})));var i=z((0,x.default)(".js-gitter-toggle-chat-button"),function(e){var i=T(e.target.getAttribute("data-gitter-toggle-chat-state"));t.toggleChat(null!==i?i:"toggle"),e.preventDefault()});this[S].push(i),e.targetElement.forEach(function(e){var i=new l.default("gitter-chat-started",{detail:{chat:t}});e.dispatchEvent(i)});var r=new l.default("gitter-sidecar-instance-started",{detail:{chat:this}});document.dispatchEvent(r)}},{key:U,value:function(){if(!this[k]){var t=this[A],e=w(t);this[C].add(e)}this[k]=!0}},{key:Y,value:function(t){var e=this[A];e.targetElement||console.warn("Gitter Sidecar: No chat embed elements to toggle visibility on");var i=e.targetElement;i.forEach(function(e){"toggle"===t?x.toggleClass(e,"is-collapsed"):x.toggleClass(e,"is-collapsed",!t);var i=new l.default("gitter-chat-toggle",{detail:{state:t}});e.dispatchEvent(i)})}},{key:"toggleChat",value:function(t){var e=this,i=this[A];if(t&&!this[k]){var r=i.targetElement;r.forEach(function(t){t.classList.add("is-loading")}),setTimeout(function(){e[U](),e[Y](t),r.forEach(function(t){t.classList.remove("is-loading")})},300)}else this[U](),this[Y](t)}},{key:"destroy",value:function(){this[S].forEach(function(t){t()}),this[C].destroy()}},{key:"options",get:function(){return(0,j.default)(this[A])}}]),t}();e.default=Q},function(t,e){"use strict";function i(t){if(Array.isArray(t)){for(var e=0,i=Array(t.length);eiframe{box-sizing:border-box;-ms-flex:1;flex:1;width:100%;height:100%;border:0}.gitter-chat-embed-loading-wrapper{box-sizing:border-box;position:absolute;top:0;left:0;bottom:0;right:0;display:none;-ms-flex-pack:center;justify-content:center;-ms-flex-align:center;align-items:center}.is-loading .gitter-chat-embed-loading-wrapper{box-sizing:border-box;display:-ms-flexbox;display:flex}.gitter-chat-embed-loading-indicator{box-sizing:border-box;opacity:.75;background-image:url();animation:spin 2s infinite linear}@keyframes spin{0%{box-sizing:border-box;transform:rotate(0deg)}to{box-sizing:border-box;transform:rotate(359.9deg)}}.gitter-chat-embed-action-bar{box-sizing:border-box;position:absolute;top:0;left:0;right:0;display:-ms-flexbox;display:flex;-ms-flex-pack:end;justify-content:flex-end;padding-bottom:.7em;background:linear-gradient(180deg,#fff 0,#fff 50%,hsla(0,0%,100%,0))}.gitter-chat-embed-action-bar-item{box-sizing:border-box;display:-ms-flexbox;display:flex;-ms-flex-pack:center;justify-content:center;-ms-flex-align:center;align-items:center;width:40px;height:40px;padding-left:0;padding-right:0;opacity:.65;background:none;background-position:50%;background-repeat:no-repeat;background-size:22px 22px;border:0;outline:none;cursor:pointer;cursor:hand;transition:all .2s ease}.gitter-chat-embed-action-bar-item:focus,.gitter-chat-embed-action-bar-item:hover{box-sizing:border-box;opacity:1}.gitter-chat-embed-action-bar-item:active{box-sizing:border-box;filter:hue-rotate(80deg) saturate(150)}.gitter-chat-embed-action-bar-item-pop-out{box-sizing:border-box;margin-right:-4px;background-image:url()}.gitter-chat-embed-action-bar-item-collapse-chat{box-sizing:border-box;background-image:url()}.gitter-open-chat-button{z-index:100;position:fixed;bottom:0;right:10px;padding:1em 3em;background-color:#36bc98;border:0;border-top-left-radius:.5em;border-top-right-radius:.5em;font-family:sans-serif;font-size:12px;letter-spacing:1px;text-transform:uppercase;text-align:center;text-decoration:none;cursor:pointer;cursor:hand;transition:all .3s ease}.gitter-open-chat-button,.gitter-open-chat-button:visited{box-sizing:border-box;color:#fff}.gitter-open-chat-button:focus,.gitter-open-chat-button:hover{box-sizing:border-box;background-color:#3ea07f;color:#fff}.gitter-open-chat-button:focus{box-sizing:border-box;box-shadow:0 0 8px rgba(62,160,127,.6);outline:none}.gitter-open-chat-button:active{box-sizing:border-box;color:#eee}.gitter-open-chat-button.is-collapsed{box-sizing:border-box;transform:translateY(120%)}',""])},function(t,e){t.exports=function(){var t=[];return t.toString=function(){for(var t=[],e=0;e` style, as it will: + +* Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`. +* Preserves the top-level `app` instance. ## Third party middleware @@ -216,3 +255,35 @@ when proxy servers are being used, based on the `X-Forwarded-Proto` and `X-Forwa A middleware class to emit timing information (cpu and wall time) for each request which passes through it. Includes examples for how to emit these timings as statsd metrics. + +#### [datasette-auth-github](https://github.com/simonw/datasette-auth-github) + +This middleware adds authentication to any ASGI application, requiring users to sign in +using their GitHub account (via [OAuth](https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps/)). +Access can be restricted to specific users or to members of specific GitHub organizations or teams. + +#### [PrometheusMiddleware](https://github.com/perdy/starlette-prometheus) + +A middleware class for capturing Prometheus metrics related to requests and responses, including in progress requests, timing... + +#### [BugsnagMiddleware](https://github.com/ashinabraham/starlette-bugsnag) + +A middleware class for logging exceptions to [Bugsnag](https://www.bugsnag.com/). + +#### [EarlyDataMiddleware](https://github.com/HarrySky/starlette-early-data) + +Middleware and decorator for detecting and denying [TLSv1.3 early data](https://tools.ietf.org/html/rfc8470) requests. + +#### [AuthlibMiddleware](https://github.com/aogier/starlette-authlib) + +A drop-in replacement for Starlette session middleware, using [authlib's jwt](https://docs.authlib.org/en/latest/jose/jwt.html) +module. + +#### [StarletteOpentracing](https://github.com/acidjunk/starlette-opentracing) + +A middleware class that emits tracing info to [OpenTracing.io](https://opentracing.io/) compatible tracers and +can be used to profile and monitor distributed applications. + +#### [RateLimitMiddleware](https://github.com/abersheeran/asgi-ratelimit) + +A rate limit middleware. Regular expression matches url; flexible rules; highly customizable. Very easy to use. diff --git a/docs/release-notes.md b/docs/release-notes.md index cfe83993d..7305046f9 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,4 +1,162 @@ -## 0.12.0 +## 0.16.0 + +July 19, 2021 + +### Added + * Added [Encode](https://github.com/sponsors/encode) funding option + [#1219](https://github.com/encode/starlette/pull/1219) + +### Fixed + * `starlette.websockets.WebSocket` instances are now hashable and compare by identity + [#1039](https://github.com/encode/starlette/pull/1039) + * A number of fixes related to running task groups in lifespan + [#1213](https://github.com/encode/starlette/pull/1213), + [#1227](https://github.com/encode/starlette/pull/1227) + +### Deprecated/removed + * The method `starlette.templates.Jinja2Templates.get_env` was removed + [#1218](https://github.com/encode/starlette/pull/1218) + * The ClassVar `starlette.testclient.TestClient.async_backend` was removed, + the backend is now configured using constructor kwargs + [#1211](https://github.com/encode/starlette/pull/1211) + * Passing an Async Generator Function or a Generator Function to `starlette.router.Router(lifespan_context=)` is deprecated. You should wrap your lifespan in `@contextlib.asynccontextmanager`. + [#1227](https://github.com/encode/starlette/pull/1227) + [#1110](https://github.com/encode/starlette/pull/1110) + +## 0.15.0 + +June 23, 2021 + +This release includes major changes to the low-level asynchronous parts of Starlette. As a result, +**Starlette now depends on [AnyIO](https://anyio.readthedocs.io/en/stable/)** and some minor API +changes have occurred. Another significant change with this release is the +**deprecation of built-in GraphQL support**. + +### Added +* Starlette now supports [Trio](https://trio.readthedocs.io/en/stable/) as an async runtime via + AnyIO - [#1157](https://github.com/encode/starlette/pull/1157). +* `TestClient.websocket_connect()` now must be used as a context manager. +* Initial support for Python 3.10 - [#1201](https://github.com/encode/starlette/pull/1201). +* The compression level used in `GZipMiddleware` is now adjustable - + [#1128](https://github.com/encode/starlette/pull/1128). + +### Fixed +* Several fixes to `CORSMiddleware`. See [#1111](https://github.com/encode/starlette/pull/1111), + [#1112](https://github.com/encode/starlette/pull/1112), + [#1113](https://github.com/encode/starlette/pull/1113), + [#1199](https://github.com/encode/starlette/pull/1199). +* Improved exception messages in the case of duplicated path parameter names - + [#1177](https://github.com/encode/starlette/pull/1177). +* `RedirectResponse` now uses `quote` instead of `quote_plus` encoding for the `Location` header + to better match the behaviour in other frameworks such as Django - + [#1164](https://github.com/encode/starlette/pull/1164). +* Exception causes are now preserved in more cases - + [#1158](https://github.com/encode/starlette/pull/1158). +* Session cookies now use the ASGI root path in the case of mounted applications - + [#1147](https://github.com/encode/starlette/pull/1147). +* Fixed a cache invalidation bug when static files were deleted in certain circumstances - + [#1023](https://github.com/encode/starlette/pull/1023). +* Improved memory usage of `BaseHTTPMiddleware` when handling large responses - + [#1012](https://github.com/encode/starlette/issues/1012) fixed via #1157 + +### Deprecated/removed + +* Built-in GraphQL support via the `GraphQLApp` class has been deprecated and will be removed in a + future release. Please see [#619](https://github.com/encode/starlette/issues/619). GraphQL is not + supported on Python 3.10. +* The `executor` parameter to `GraphQLApp` was removed. Use `executor_class` instead. +* The `workers` parameter to `WSGIMiddleware` was removed. This hasn't had any effect since + Starlette v0.6.3. + +## 0.14.2 + +February 2, 2021 + +### Fixed + +* Fixed `ServerErrorMiddleware` compatibility with Python 3.9.1/3.8.7 when debug mode is enabled - + [#1132](https://github.com/encode/starlette/pull/1132). +* Fixed unclosed socket `ResourceWarning`s when using the `TestClient` with WebSocket endpoints - + #1132. +* Improved detection of `async` endpoints wrapped in `functools.partial` on Python 3.8+ - + [#1106](https://github.com/encode/starlette/pull/1106). + + +## 0.14.1 + +November 9th, 2020 + +### Removed + +* `UJSONResponse` was removed (this change was intended to be included in 0.14.0). Please see the + [documentation](https://www.starlette.io/responses/#custom-json-serialization) for how to + implement responses using custom JSON serialization - + [#1074](https://github.com/encode/starlette/pull/1047). + +## 0.14.0 + +November 8th, 2020 + +### Added + +* Starlette now officially supports Python3.9. +* In `StreamingResponse`, allow custom async iterator such as objects from classes implementing `__aiter__`. +* Allow usage of `functools.partial` async handlers in Python versions 3.6 and 3.7. +* Add 418 I'm A Teapot status code. + +### Changed + +* Create tasks from handler coroutines before sending them to `asyncio.wait`. +* Use `format_exception` instead of `format_tb` in `ServerErrorMiddleware`'s `debug` responses. +* Be more lenient with handler arguments when using the `requires` decorator. + +## 0.13.8 + +* Revert `Queue(maxsize=1)` fix for `BaseHTTPMiddleware` middleware classes and streaming responses. + +* The `StaticFiles` constructor now allows `pathlib.Path` in addition to strings for its `directory` argument. + +## 0.13.7 + +* Fix high memory usage when using `BaseHTTPMiddleware` middleware classes and streaming responses. + +## 0.13.6 + +* Fix 404 errors with `StaticFiles`. + +## 0.13.5 + +* Add support for `Starlette(lifespan=...)` functions. +* More robust path-traversal check in StaticFiles app. +* Fix WSGI PATH_INFO encoding. +* RedirectResponse now accepts optional background parameter +* Allow path routes to contain regex meta characters +* Treat ASGI HTTP 'body' as an optional key. +* Don't use thread pooling for writing to in-memory upload files. + +## 0.13.0 + +* Switch to promoting application configuration on init style everywhere. + This means dropping the decorator style in favour of declarative routing + tables and middleware definitions. + +## 0.12.12 + +* Fix `request.url_for()` for the Mount-within-a-Mount case. + +## 0.12.11 + +* Fix `request.url_for()` when an ASGI `root_path` is being used. + +## 0.12.1 + +* Add `URL.include_query_params(**kwargs)` +* Add `URL.replace_query_params(**kwargs)` +* Add `URL.remove_query_params(param_names)` +* `request.state` properly persisting across middleware. +* Added `request.scope` interface. + +## 0.12.0 * Switch to ASGI 3.0. * Fixes to CORS middleware. diff --git a/docs/requests.md b/docs/requests.md index 608caba7e..a72cb75dc 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -11,16 +11,12 @@ from starlette.requests import Request from starlette.responses import Response -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - request = Request(self.scope, receive) - content = '%s %s' % (request.method, request.url.path) - response = Response(content, media_type='text/plain') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + request = Request(scope, receive) + content = '%s %s' % (request.method, request.url.path) + response = Response(content, media_type='text/plain') + await response(scope, receive, send) ``` Requests present a mapping interface, so you can use them in the same @@ -77,6 +73,8 @@ Cookies are exposed as a regular dictionary interface. For example: `request.cookies.get('mycookie')` +Cookies are ignored in case of an invalid cookie. (RFC2109) + #### Body There are a few different interfaces for returning the body of the request: @@ -93,19 +91,15 @@ You can also access the request body as a stream, using the `async for` syntax: from starlette.requests import Request from starlette.responses import Response - -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - request = Request(self.scope, receive) - body = b'' - async for chunk in request.stream(): - body += chunk - response = Response(body, media_type='text/plain') - await response(receive, send) + +async def app(scope, receive, send): + assert scope['type'] == 'http' + request = Request(scope, receive) + body = b'' + async for chunk in request.stream(): + body += chunk + response = Response(body, media_type='text/plain') + await response(scope, receive, send) ``` If you access `.stream()` then the byte chunks are provided without storing @@ -148,6 +142,10 @@ filename = form["upload_file"].filename contents = await form["upload_file"].read() ``` +#### Application + +The originating Starlette application can be accessed via `request.app`. + #### Other state If you want to store additional information on the request you can do so diff --git a/docs/responses.md b/docs/responses.md index 9aaf24e1f..c4cd84ed3 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -22,20 +22,16 @@ ASGI application instance. from starlette.responses import Response -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = Response('Hello, world!', media_type='text/plain') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = Response('Hello, world!', media_type='text/plain') + await response(scope, receive, send) ``` #### Set Cookie Starlette provides a `set_cookie` method to allow you to set cookies on the response object. -Signature: `Response.set_cookie(key, value, max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False)` +Signature: `Response.set_cookie(key, value, max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax")` * `key` - A string that will be the cookie's key. * `value` - A string that will be the cookie's value. @@ -44,7 +40,8 @@ Signature: `Response.set_cookie(key, value, max_age=None, expires=None, path="/" * `path` - A string that specifies the subset of routes to which the cookie will apply. `Optional` * `domain` - A string that specifies the domain for which the cookie is valid. `Optional` * `secure` - A bool indicating that the cookie will only be sent to the server if request is made using SSL and the HTTPS protocol. `Optional` -* `httponly` - A bool indicating that the cookie cannot be accessed via Javascript through `Document.cookie` property, the `XMLHttpRequest` or `Request` APIs. `Optional` +* `httponly` - A bool indicating that the cookie cannot be accessed via JavaScript through `Document.cookie` property, the `XMLHttpRequest` or `Request` APIs. `Optional` +* `samesite` - A string that specifies the samesite strategy for the cookie. Valid values are `'lax'`, `'strict'` and `'none'`. Defaults to `'lax'`. `Optional` #### Delete Cookie @@ -61,32 +58,24 @@ Takes some text or bytes and returns an HTML response. from starlette.responses import HTMLResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = HTMLResponse('

Hello, world!

') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = HTMLResponse('

Hello, world!

') + await response(scope, receive, send) ``` ### PlainTextResponse -Takes some text or bytes and returns an plain text response. +Takes some text or bytes and returns a plain text response. ```python from starlette.responses import PlainTextResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = PlainTextResponse('Hello, world!') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = PlainTextResponse('Hello, world!') + await response(scope, receive, send) ``` ### JSONResponse @@ -97,59 +86,51 @@ Takes some data and returns an `application/json` encoded response. from starlette.responses import JSONResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = JSONResponse({'hello': 'world'}) - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = JSONResponse({'hello': 'world'}) + await response(scope, receive, send) ``` -### UJSONResponse +#### Custom JSON serialization -A JSON response class that uses the optimised `ujson` library for serialisation. +If you need fine-grained control over JSON serialization, you can subclass +`JSONResponse` and override the `render` method. -Using `ujson` will result in faster JSON serialisation, but is also less careful -than Python's built-in implementation in how it handles some edge-cases. - -In general you *probably* want to stick with `JSONResponse` by default unless -you are micro-optimising a particular endpoint. +For example, if you wanted to use a third-party JSON library such as +[orjson](https://pypi.org/project/orjson/): ```python -from starlette.responses import UJSONResponse +from typing import Any +import orjson +from starlette.responses import JSONResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - async def __call__(self, receive, send): - response = UJSONResponse({'hello': 'world'}) - await response(receive, send) +class OrjsonResponse(JSONResponse): + def render(self, content: Any) -> bytes: + return orjson.dumps(content) ``` +In general you *probably* want to stick with `JSONResponse` by default unless +you are micro-optimising a particular endpoint or need to serialize non-standard +object types. + ### RedirectResponse -Returns an HTTP redirect. Uses a 302 status code by default. +Returns an HTTP redirect. Uses a 307 status code by default. ```python from starlette.responses import PlainTextResponse, RedirectResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - if self.scope['path'] != '/': - response = RedirectResponse(url='/') - else: - response = PlainTextResponse('Hello, world!') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + if scope['path'] != '/': + response = RedirectResponse(url='/') + else: + response = PlainTextResponse('Hello, world!') + await response(scope, receive, send) ``` ### StreamingResponse @@ -169,15 +150,11 @@ async def slow_numbers(minimum, maximum): yield('') -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - generator = slow_numbers(1, 10) - response = StreamingResponse(generator, media_type='text/html') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + generator = slow_numbers(1, 10) + response = StreamingResponse(generator, media_type='text/html') + await response(scope, receive, send) ``` Have in mind that file-like objects (like those created by `open()`) are normal iterators. So, you can return them directly in a `StreamingResponse`. @@ -199,12 +176,15 @@ File responses will include appropriate `Content-Length`, `Last-Modified` and `E from starlette.responses import FileResponse -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = FileResponse('statics/favicon.ico') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = FileResponse('statics/favicon.ico') + await response(scope, receive, send) ``` + +## Third party middleware + +### [SSEResponse(EventSourceResponse)](https://github.com/sysid/sse-starlette) + +Server Sent Response implements the ServerSentEvent Protocol: https://www.w3.org/TR/2009/WD-eventsource-20090421. +It enables event streaming from the server to the client without the complexity of websockets. diff --git a/docs/routing.md b/docs/routing.md index 9b0329b0c..5b712d1a4 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -1,65 +1,219 @@ - -Starlette includes a `Router` class which is an ASGI application that -dispatches incoming requests to endpoints or submounted applications. +Starlette has a simple but capable request routing system. A routing table +is defined as a list of routes, and passed when instantiating the application. ```python -from starlette.routing import Mount, Route, Router -from myproject import Homepage, SubMountedApp +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse +from starlette.routing import Route -app = Router([ - Route('/', endpoint=Homepage, methods=['GET']), - Mount('/mount', app=SubMountedApp) -]) +async def homepage(request): + return PlainTextResponse("Homepage") + +async def about(request): + return PlainTextResponse("About") + + +routes = [ + Route("/", endpoint=homepage), + Route("/about", endpoint=about), +] + +app = Starlette(routes=routes) ``` +The `endpoint` argument can be one of: + +* A regular function or async function, which accepts a single `request` +argument and which should return a response. +* A class that implements the ASGI interface, such as Starlette's [class based +views](endpoints.md). + +## Path Parameters + Paths can use URI templating style to capture path components. ```python -Route('/users/{username}', endpoint=User, methods=['GET']) +Route('/users/{username}', user) ``` +By default this will capture characters up to the end of the path or the next `/`. + +You can use convertors to modify what is captured. Four convertors are available: -Convertors for `int`, `float`, and `path` are also available: +* `str` returns a string, and is the default. +* `int` returns a Python integer. +* `float` returns a Python float. +* `uuid` return a Python `uuid.UUID` instance. +* `path` returns the rest of the path, including any additional `/` characters. + +Convertors are used by prefixing them with a colon, like so: ```python -Route('/users/{user_id:int}', endpoint=User, methods=['GET']) +Route('/users/{user_id:int}', user) +Route('/floating-point/{number:float}', floating_point) +Route('/uploaded/{rest_of_path:path}', uploaded) ``` Path parameters are made available in the request, as the `request.path_params` dictionary. -Because the target of a `Mount` is an ASGI instance itself, routers -allow for easy composition. For example: +```python +async def user(request): + user_id = request.path_params['user_id'] + ... +``` + +## Handling HTTP methods + +Routes can also specify which HTTP methods are handled by an endpoint: ```python -app = Router([ - Route('/', endpoint=Homepage, methods=['GET']), - Mount('/users', app=Router([ - Route('/', endpoint=Users, methods=['GET', 'POST']), - Route('/{username}', endpoint=User, methods=['GET']), - ])) -]) +Route('/users/{user_id:int}', user, methods=["GET", "POST"]) +``` + +By default function endpoints will only accept `GET` requests, unless specified. + +## Submounting routes + +In large applications you might find that you want to break out parts of the +routing table, based on a common path prefix. + +```python +routes = [ + Route('/', homepage), + Mount('/users', routes=[ + Route('/', users, methods=['GET', 'POST']), + Route('/{username}', user), + ]) +] +``` + +This style allows you to define different subsets of the routing table in +different parts of your project. + +```python +from myproject import users, auth + +routes = [ + Route('/', homepage), + Mount('/users', routes=users.routes), + Mount('/auth', routes=auth.routes), +] +``` + +You can also use mounting to include sub-applications within your Starlette +application. For example... + +```python +# This is a standalone static files server: +app = StaticFiles(directory="static") + +# This is a static files server mounted within a Starlette application, +# underneath the "/static" path. +routes = [ + ... + Mount("/static", app=StaticFiles(directory="static"), name="static") +] + +app = Starlette(routes=routes) +``` + +## Reverse URL lookups + +You'll often want to be able to generate the URL for a particular route, +such as in cases where you need to return a redirect response. + +```python +routes = [ + Route("/", homepage, name="homepage") +] + +# We can use the following to return a URL... +url = request.url_for("homepage") +``` + +URL lookups can include path parameters... + +```python +routes = [ + Route("/users/{username}", user, name="user_detail") +] + +# We can use the following to return a URL... +url = request.url_for("user_detail", username=...) ``` -The router will respond with "404 Not found" or "405 Method not allowed" -responses for requests which do not match. +If a `Mount` includes a `name`, then submounts should use a `{prefix}:{name}` +style for reverse URL lookups. + +```python +routes = [ + Mount("/users", name="users", routes=[ + Route("/", user, name="user_list"), + Route("/{username}", user, name="user_detail") + ]) +] + +# We can use the following to return URLs... +url = request.url_for("users:user_list") +url = request.url_for("users:user_detail", username=...) +``` + +Mounted applications may include a `path=...` parameter. + +```python +routes = [ + ... + Mount("/static", app=StaticFiles(directory="static"), name="static") +] + +# We can use the following to return URLs... +url = request.url_for("static", path="/css/base.css") +``` + +For cases where there is no `request` instance, you can make reverse lookups +against the application, although these will only return the URL path. + +```python +url = app.url_path_for("user_detail", username=...) +``` + +## Route priority Incoming paths are matched against each `Route` in order. -If you need to have a `Route` with a fixed path that would also match a -`Route` with parameters you should add the `Route` with the fixed path first. +In cases where more that one route could match an incoming path, you should +take care to ensure that more specific routes are listed before general cases. -For example, with an additional `Route` like: +For example: ```python -Route('/users/me', endpoint=UserMe, methods=['GET']) +# Don't do this: `/users/me` will never match incoming requests. +routes = [ + Route('/users/{username}', user), + Route('/users/me', current_user), +] + +# Do this: `/users/me` is tested first. +routes = [ + Route('/users/me', current_user), + Route('/users/{username}', user), +] ``` -You should add that route for `/users/me` before the one for `/users/{username}`: +## Working with Router instances + +If you're working at a low-level you might want to use a plain `Router` +instance, rather that creating a `Starlette` application. This gives you +a lightweight ASGI application that just provides the application routing, +without wrapping it up in any middleware. ```python -app = Router([ - Route('/users/me', endpoint=UserMe, methods=['GET']), - Route('/users/{username}', endpoint=User, methods=['GET']), +app = Router(routes=[ + Route('/', homepage), + Mount('/users', routes=[ + Route('/', users, methods=['GET', 'POST']), + Route('/{username}', user), + ]) ]) ``` diff --git a/docs/schemas.md b/docs/schemas.md index 5f0a2fb38..2530ba8d1 100644 --- a/docs/schemas.md +++ b/docs/schemas.md @@ -11,16 +11,14 @@ the docstrings. ```python from starlette.applications import Starlette +from starlette.routing import Route from starlette.schemas import SchemaGenerator schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} ) -app = Starlette() - -@app.route("/users", methods=["GET"]) def list_users(request): """ responses: @@ -32,7 +30,6 @@ def list_users(request): raise NotImplementedError() -@app.route("/users", methods=["POST"]) def create_user(request): """ responses: @@ -44,9 +41,17 @@ def create_user(request): raise NotImplementedError() -@app.route("/schema", methods=["GET"], include_in_schema=False) def openapi_schema(request): return schemas.OpenAPIResponse(request=request) + + +routes = [ + Route("/users", endpoint=list_users, methods=["GET"]) + Route("/users", endpoint=create_user, methods=["POST"]) + Route("/schema", endpoint=openapi_schema, include_in_schema=False) +] + +app = Starlette() ``` We can now access an OpenAPI schema at the "/schema" endpoint. @@ -86,7 +91,7 @@ if __name__ == '__main__': assert sys.argv[-1] in ("run", "schema"), "Usage: example.py [run|schema]" if sys.argv[-1] == "run": - uvicorn.run(app, host='0.0.0.0', port=8000) + uvicorn.run("example:app", host='0.0.0.0', port=8000) elif sys.argv[-1] == "schema": schema = schemas.get_schema(routes=app.routes) print(yaml.dump(schema, default_flow_style=False)) diff --git a/docs/server-push.md b/docs/server-push.md new file mode 100644 index 000000000..c97014d97 --- /dev/null +++ b/docs/server-push.md @@ -0,0 +1,36 @@ + +Starlette includes support for HTTP/2 and HTTP/3 server push, making it +possible to push resources to the client to speed up page load times. + +### `Request.send_push_promise` + +Used to initiate a server push for a resource. If server push is not available +this method does nothing. + +Signature: `send_push_promise(path)` + +* `path` - A string denoting the path of the resource. + +```python +from starlette.applications import Starlette +from starlette.responses import HTMLResponse +from starlette.routing import Route, Mount +from starlette.staticfiles import StaticFiles + + +async def homepage(request): + """ + Homepage which uses server push to deliver the stylesheet. + """ + await request.send_push_promise("/static/style.css") + return HTMLResponse( + '' + ) + +routes = [ + Route("/", endpoint=homepage), + Mount("/static", StaticFiles(directory="static"), name="static") +] + +app = Starlette(routes=routes) +``` diff --git a/docs/staticfiles.md b/docs/staticfiles.md index a3710473a..d8786af4d 100644 --- a/docs/staticfiles.md +++ b/docs/staticfiles.md @@ -5,39 +5,51 @@ Starlette also includes a `StaticFiles` class for serving files in a given direc Signature: `StaticFiles(directory=None, packages=None, check_dir=True)` -* `directory` - A string denoting a directory path. +* `directory` - A string or [os.Pathlike][pathlike] denoting a directory path. * `packages` - A list of strings of python packages. +* `html` - Run in HTML mode. Automatically loads `index.html` for directories if such file exist. * `check_dir` - Ensure that the directory exists upon instantiation. Defaults to `True`. You can combine this ASGI application with Starlette's routing to provide comprehensive static file serving. ```python -from starlette.routing import Router, Mount +from starlette.applications import Starlette +from starlette.routing import Mount from starlette.staticfiles import StaticFiles -app = Router(routes=[ +routes = [ + ... Mount('/static', app=StaticFiles(directory='static'), name="static"), -]) +] + +app = Starlette(routes=routes) ``` Static files will respond with "404 Not found" or "405 Method not allowed" -responses for requests which do not match. +responses for requests which do not match. In HTML mode if `404.html` file +exists it will be shown as 404 response. The `packages` option can be used to include "static" directories contained within a python package. The Python "bootstrap4" package is an example of this. ```python -from starlette.routing import Router, Mount +from starlette.applications import Starlette +from starlette.routing import Mount from starlette.staticfiles import StaticFiles -app = Router(routes=[ +routes=[ + ... Mount('/static', app=StaticFiles(directory='static', packages=['bootstrap4']), name="static"), -]) +] + +app = Starlette(routes=routes) ``` You may prefer to include static files directly inside the "static" directory rather than using Python packaging to include static files, but it can be useful for bundling up reusable components. + +[pathlike]: https://docs.python.org/3/library/os.html#os.PathLike diff --git a/docs/templates.md b/docs/templates.md index bcc8e56ab..b67669920 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -6,18 +6,22 @@ what you want to use by default. ```python from starlette.applications import Starlette +from starlette.routing import Route, Mount from starlette.templating import Jinja2Templates +from starlette.staticfiles import StaticFiles templates = Jinja2Templates(directory='templates') -app = Starlette(debug=True) -app.mount('/static', StaticFiles(directory='statics'), name='static') - - -@app.route('/') async def homepage(request): return templates.TemplateResponse('index.html', {'request': request}) + +routes = [ + Route('/', endpoint=homepage), + Mount('/static', StaticFiles(directory='static'), name='static') +] + +app = Starlette(debug=True, routes=routes) ``` Note that the incoming `request` instance must be included as part of the @@ -32,6 +36,20 @@ For example, we can link to static files from within our HTML templates: ``` +If you want to use [custom filters][jinja2], you will need to update the `env` +property of `Jinja2Templates`: + +```python +from commonmark import commonmark +from starlette.templating import Jinja2Templates + +def marked_filter(text): + return commonmark(text) + +templates = Jinja2Templates(directory='templates') +templates.env.filters['marked'] = marked_filter +``` + ## Testing template responses When using the test client, template responses include `.template` and `.context` @@ -55,3 +73,5 @@ database lookups, or other I/O operations. Instead we'd recommend that you ensure that your endpoints perform all I/O, for example, strictly evaluate any database queries within the view and include the final results in the context. + +[jinja2]: https://jinja.palletsprojects.com/en/2.10.x/api/?highlight=environment#writing-filters diff --git a/docs/testclient.md b/docs/testclient.md index 6065e4d09..a1861efec 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -7,18 +7,14 @@ from starlette.responses import HTMLResponse from starlette.testclient import TestClient -class App: - def __init__(self, scope): - assert scope['type'] == 'http' - self.scope = scope - - async def __call__(self, receive, send): - response = HTMLResponse('Hello, world!') - await response(receive, send) +async def app(scope, receive, send): + assert scope['type'] == 'http' + response = HTMLResponse('Hello, world!') + await response(scope, receive, send) def test_app(): - client = TestClient(App) + client = TestClient(app) response = client.get('/') assert response.status_code == 200 ``` @@ -35,6 +31,29 @@ application. Occasionally you might want to test the content of 500 error responses, rather than allowing client to raise the server exception. In this case you should use `client = TestClient(app, raise_server_exceptions=False)`. +### Selecting the Async backend + +`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary). +These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) +for more information about the accepted backend options. +By default, `asyncio` is used with default options. + +To run `Trio`, pass `backend="trio"`. For example: + +```python +def test_app() + with TestClient(app, backend="trio") as client: + ... +``` + +To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example: + +```python +def test_app() + with TestClient(app, backend_options={"use_uvloop": True}) as client: + ... +``` + ### Testing WebSocket sessions You can also test websocket sessions with the test client. @@ -48,20 +67,16 @@ from starlette.testclient import TestClient from starlette.websockets import WebSocket -class App: - def __init__(self, scope): - assert scope['type'] == 'websocket' - self.scope = scope - - async def __call__(self, receive, send): - websocket = WebSocket(self.scope, receive=receive, send=send) - await websocket.accept() - await websocket.send_text('Hello, world!') - await websocket.close() +async def app(scope, receive, send): + assert scope['type'] == 'websocket' + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_text('Hello, world!') + await websocket.close() def test_app(): - client = TestClient(App) + client = TestClient(app) with client.websocket_connect('/') as websocket: data = websocket.receive_text() assert data == 'Hello, world!' @@ -78,7 +93,9 @@ always raised by the test client. * `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`. -May raise `starlette.websockets.Disconnect` if the application does not accept the websocket connection. +May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. + +`websocket_connect()` must be used as a context manager (in a `with` block). #### Sending data @@ -92,7 +109,7 @@ May raise `starlette.websockets.Disconnect` if the application does not accept t * `.receive_bytes()` - Wait for incoming bytestring sent by the application and return it. * `.receive_json(mode="text")` - Wait for incoming json data sent by the application and return it. Use `mode="binary"` to send JSON over binary data frames. -May raise `starlette.websockets.Disconnect`. +May raise `starlette.websockets.WebSocketDisconnect`. #### Closing the connection diff --git a/docs/third-party-packages.md b/docs/third-party-packages.md index 656387ffe..71902ce83 100644 --- a/docs/third-party-packages.md +++ b/docs/third-party-packages.md @@ -20,12 +20,44 @@ Simple APISpec integration for Starlette. Document your REST API built with Starlette by declaring OpenAPI (Swagger) schemas in YAML format in your endpoint's docstrings. -### Starlette API +### SpecTree -GitHub +GitHub -That library aims to bring a layer on top of Starlette framework to provide useful mechanism for building APIs. Based on API Star. Some featuers: marshmallow schemas, dependency injection, auto generated api schemas, -auto generated docs. +Generate OpenAPI spec document and validate request & response with Python annotations. Less boilerplate code(no need for YAML). + +### Mangum + +GitHub + +Serverless ASGI adapter for AWS Lambda & API Gateway. + +### Nejma + +GitHub + +Manage and send messages to groups of channels using websockets. +Checkout nejma-chat, a simple chat application built using `nejma` and `starlette`. + +### ChannelBox + +GitHub + +Another solution for websocket broadcast. Send messages to channel groups from any part of your code. +Checkout channel-box-chat, a simple chat application built using `channel-box` and `starlette`. + +### Scout APM + +GitHub + +An APM (Application Performance Monitoring) solution that can +instrument your application to find performance bottlenecks. + +### Starlette Prometheus + +GitHub + +A plugin for providing an endpoint that exposes [Prometheus](https://prometheus.io/) metrics based on its [official python client](https://github.com/prometheus/client_python). ### webargs-starlette @@ -37,24 +69,32 @@ of [webargs](https://github.com/marshmallow-code/webargs). Allows you to parse querystring, JSON, form, headers, and cookies using type annotations. -### Mangum +### Authlib -GitHub +GitHub | +Documentation -Serverless ASGI adapter for AWS Lambda & API Gateway. +The ultimate Python library in building OAuth and OpenID Connect clients and servers. Check out how to integrate with [Starlette](https://docs.authlib.org/en/latest/client/starlette.html). -### Nejma +### Starlette OAuth2 API -GitHub +GitLab -Manage and send messages to groups of channels using websockets. -Checkout nejma-chat, a simple chat application built using `nejma` and `starlette`. +A starlette middleware to add authentication and authorization through JWTs. +It relies solely on an auth provider to issue access and/or id tokens to clients. + +### Starlette Context + +GitHub + +Middleware for Starlette that allows you to store and access the context data of a request. +Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id. ## Frameworks ### Responder -GitHub | +GitHub | Documentation Async web service framework. Some Features: flask-style route expression, @@ -68,10 +108,11 @@ yaml support, OpenAPI schema generation, background tasks, graphql. High performance, easy to learn, fast to code, ready for production web API framework. Inspired by **APIStar**'s previous server system with type declarations for route parameters, based on the OpenAPI specification version 3.0.0+ (with JSON Schema), powered by **Pydantic** for the data handling. -### Bocadillo +### Flama + +GitHub | +Documentation -GitHub | -Documentation +Formerly Starlette API. -A modern Python web framework filled with asynchronous salsa. -Bocadillo is **async-first** and designed with productivity and simplicity in mind. It is not meant to be minimal: a **carefully chosen set of included batteries** helps you build performant web apps and services with minimal setup. +Flama aims to bring a layer on top of Starlette to provide an **easy to learn** and **fast to develop** approach for building **highly performant** GraphQL and REST APIs. In the same way of Starlette is, Flama is a perfect option for developing **asynchronous** and **production-ready** services. diff --git a/docs/websockets.md b/docs/websockets.md index 06994ea48..807496188 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -10,16 +10,11 @@ Signature: `WebSocket(scope, receive=None, send=None)` from starlette.websockets import WebSocket -class App: - def __init__(self, scope): - assert scope['type'] == 'websocket' - self.scope = scope - - async def __call__(self, receive, send): - websocket = WebSocket(self.scope, receive=receive, send=send) - await websocket.accept() - await websocket.send_text('Hello, world!') - await websocket.close() +async def app(scope, receive, send): + websocket = WebSocket(scope=scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_text('Hello, world!') + await websocket.close() ``` WebSockets present a mapping interface, so you can use them in the same @@ -44,7 +39,7 @@ For example: `websocket.headers['sec-websocket-version']` #### Query Parameters -Headers are exposed as an immutable multi-dict. +Query parameters are exposed as an immutable multi-dict. For example: `websocket.query_params['search']` @@ -52,7 +47,7 @@ For example: `websocket.query_params['search']` Router path parameters are exposed as a dictionary interface. -For example: `request.path_params['username']` +For example: `websocket.path_params['username']` ### Accepting the connection @@ -73,7 +68,7 @@ Use `websocket.send_json(data, mode="binary")` to send JSON over binary data fra * `await websocket.receive_bytes()` * `await websocket.receive_json()` -May raise `starlette.websockets.Disconnect()`. +May raise `starlette.websockets.WebSocketDisconnect()`. JSON messages default to being received over text data frames, from version 0.10.0 onwards. Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary data frames. diff --git a/mkdocs.yml b/mkdocs.yml index 2ec6a97c6..b1237aefb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,36 +1,44 @@ site_name: Starlette site_description: The little ASGI library that shines. +site_url: https://www.starlette.io theme: - name: 'material' + name: 'material' repo_name: encode/starlette repo_url: https://github.com/encode/starlette edit_uri: "" nav: - - Introduction: 'index.md' - - Applications: 'applications.md' - - Requests: 'requests.md' - - Responses: 'responses.md' - - WebSockets: 'websockets.md' - - Routing: 'routing.md' - - Endpoints: 'endpoints.md' - - Middleware: 'middleware.md' - - Static Files: 'staticfiles.md' - - Templates: 'templates.md' - - Database: 'database.md' - - GraphQL: 'graphql.md' - - Authentication: 'authentication.md' - - API Schemas: 'schemas.md' - - Events: 'events.md' - - Background Tasks: 'background.md' - - Exceptions: 'exceptions.md' - - Configuration: 'config.md' - - Test Client: 'testclient.md' - - Third Party Packages: 'third-party-packages.md' - - Release Notes: 'release-notes.md' + - Introduction: 'index.md' + - Applications: 'applications.md' + - Requests: 'requests.md' + - Responses: 'responses.md' + - WebSockets: 'websockets.md' + - Routing: 'routing.md' + - Endpoints: 'endpoints.md' + - Middleware: 'middleware.md' + - Static Files: 'staticfiles.md' + - Templates: 'templates.md' + - Database: 'database.md' + - GraphQL: 'graphql.md' + - Authentication: 'authentication.md' + - API Schemas: 'schemas.md' + - Events: 'events.md' + - Background Tasks: 'background.md' + - Server Push: 'server-push.md' + - Exceptions: 'exceptions.md' + - Configuration: 'config.md' + - Test Client: 'testclient.md' + - Third Party Packages: 'third-party-packages.md' + - Release Notes: 'release-notes.md' markdown_extensions: - - markdown.extensions.codehilite: - guess_lang: false + - mkautodoc + - admonition + - pymdownx.highlight + - pymdownx.superfences + +extra_javascript: + - 'js/chat.js' + - 'js/sidecar-1.5.0.js' diff --git a/requirements.txt b/requirements.txt index c73de0bf5..abc7a3b0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,31 @@ # Optionals -aiofiles -graphene +graphene; python_version<'3.10' itsdangerous jinja2 python-multipart pyyaml requests -ujson # Testing autoflake -black +black==20.8b1 +coverage>=5.3 databases[sqlite] -isort +flake8 +isort==5.* mypy +types-requests +types-contextvars +types-PyYAML +types-dataclasses pytest -pytest-cov +trio # Documentation mkdocs mkdocs-material +mkautodoc + +# Packaging +twine +wheel diff --git a/scripts/README.md b/scripts/README.md index 84015423f..7388eac4b 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -2,7 +2,10 @@ * `scripts/install` - Install dependencies in a virtual environment. * `scripts/test` - Run the test suite. -* `scripts/lint` - Run the code linting. +* `scripts/lint` - Run the automated code linting/formatting tools. +* `scripts/check` - Run the code linting, checking that it passes. +* `scripts/coverage` - Check that code coverage is complete. +* `scripts/build` - Build source and wheel packages. * `scripts/publish` - Publish the latest version to PyPI. Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all). diff --git a/scripts/build b/scripts/build new file mode 100755 index 000000000..1c47d2cc2 --- /dev/null +++ b/scripts/build @@ -0,0 +1,13 @@ +#!/bin/sh -e + +if [ -d 'venv' ] ; then + PREFIX="venv/bin/" +else + PREFIX="" +fi + +set -x + +${PREFIX}python setup.py sdist bdist_wheel +${PREFIX}twine check dist/* +${PREFIX}mkdocs build diff --git a/scripts/check b/scripts/check new file mode 100755 index 000000000..23d50c7c3 --- /dev/null +++ b/scripts/check @@ -0,0 +1,14 @@ +#!/bin/sh -e + +export PREFIX="" +if [ -d 'venv' ] ; then + export PREFIX="venv/bin/" +fi +export SOURCE_FILES="starlette tests" + +set -x + +${PREFIX}isort --check --diff --project=starlette $SOURCE_FILES +${PREFIX}black --check --diff $SOURCE_FILES +${PREFIX}flake8 $SOURCE_FILES +${PREFIX}mypy $SOURCE_FILES diff --git a/scripts/coverage b/scripts/coverage new file mode 100755 index 000000000..e871360d1 --- /dev/null +++ b/scripts/coverage @@ -0,0 +1,10 @@ +#!/bin/sh -e + +export PREFIX="" +if [ -d 'venv' ] ; then + export PREFIX="venv/bin/" +fi + +set -x + +${PREFIX}coverage report --show-missing --skip-covered --fail-under=100 diff --git a/scripts/docs b/scripts/docs new file mode 100755 index 000000000..4ac3beb7a --- /dev/null +++ b/scripts/docs @@ -0,0 +1,10 @@ +#!/bin/sh -e + +export PREFIX="" +if [ -d 'venv' ] ; then + export PREFIX="venv/bin/" +fi + +set -x + +${PREFIX}mkdocs serve diff --git a/scripts/install b/scripts/install index 852afa6c6..65885a720 100755 --- a/scripts/install +++ b/scripts/install @@ -1,24 +1,19 @@ #!/bin/sh -e # Use the Python executable provided from the `-p` option, or a default. -[[ $1 = "-p" ]] && PYTHON=$2 || PYTHON="python3" - -MIN_VERSION="(3, 6)" -VERSION_OK=`"$PYTHON" -c "import sys; print(sys.version_info[0:2] >= $MIN_VERSION and '1' or '');"` - -if [[ -z "$VERSION_OK" ]] ; then - PYTHON_VERSION=`"$PYTHON" -c "import sys; print('%s.%s' % sys.version_info[0:2]);"` - DISP_MIN_VERSION=`"$PYTHON" -c "print('%s.%s' % $MIN_VERSION)"` - echo "ERROR: Python $PYTHON_VERSION detected, but $DISP_MIN_VERSION+ is required." - echo "Please upgrade your Python distribution to install Starlette." - exit 1 -fi +[ "$1" = "-p" ] && PYTHON=$2 || PYTHON="python3" REQUIREMENTS="requirements.txt" VENV="venv" -PIP="$VENV/bin/pip" set -x -"$PYTHON" -m venv "$VENV" + +if [ -z "$GITHUB_ACTIONS" ]; then + "$PYTHON" -m venv "$VENV" + PIP="$VENV/bin/pip" +else + PIP="pip" +fi + "$PIP" install -r "$REQUIREMENTS" "$PIP" install -e . diff --git a/scripts/lint b/scripts/lint index 06e320938..92e121691 100755 --- a/scripts/lint +++ b/scripts/lint @@ -4,11 +4,10 @@ export PREFIX="" if [ -d 'venv' ] ; then export PREFIX="venv/bin/" fi +export SOURCE_FILES="starlette tests" set -x -${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs -${PREFIX}autoflake --in-place --recursive starlette tests setup.py -${PREFIX}black starlette tests setup.py -${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply starlette tests setup.py -${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs +${PREFIX}autoflake --in-place --recursive $SOURCE_FILES +${PREFIX}isort --project=starlette $SOURCE_FILES +${PREFIX}black $SOURCE_FILES diff --git a/scripts/publish b/scripts/publish index 838260d51..667103d62 100755 --- a/scripts/publish +++ b/scripts/publish @@ -1,28 +1,26 @@ #!/bin/sh -e -export VERSION=`cat starlette/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"` -export PREFIX="" +VERSION_FILE="starlette/__init__.py" + if [ -d 'venv' ] ; then - export PREFIX="venv/bin/" + PREFIX="venv/bin/" +else + PREFIX="" fi -scripts/clean +if [ ! -z "$GITHUB_ACTIONS" ]; then + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git config --local user.name "GitHub Action" + + VERSION=`grep __version__ ${VERSION_FILE} | grep -o '[0-9][^"]*'` -if ! command -v "${PREFIX}twine" &>/dev/null ; then - echo "Unable to find the 'twine' command." - echo "Install from PyPI, using '${PREFIX}pip install twine'." + if [ "refs/tags/${VERSION}" != "${GITHUB_REF}" ] ; then + echo "GitHub Ref '${GITHUB_REF}' did not match package version '${VERSION}'" exit 1 + fi fi -find starlette -type f -name "*.py[co]" -delete -find starlette -type d -name __pycache__ -delete +set -x -${PREFIX}python setup.py sdist ${PREFIX}twine upload dist/* -${PREFIX}mkdocs gh-deploy - -echo "You probably want to also tag the version now:" -echo "git tag -a ${VERSION} -m 'version ${VERSION}'" -echo "git push --tags" - -scripts/clean +${PREFIX}mkdocs gh-deploy --force diff --git a/scripts/test b/scripts/test index 7a394e0b0..720a66392 100755 --- a/scripts/test +++ b/scripts/test @@ -1,16 +1,18 @@ -#!/bin/sh -e +#!/bin/sh export PREFIX="" if [ -d 'venv' ] ; then export PREFIX="venv/bin/" fi -export VERSION_SCRIPT="import sys; print('%s.%s' % sys.version_info[0:2])" -export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"` +set -ex -set -x +if [ -z $GITHUB_ACTIONS ]; then + scripts/check +fi + +${PREFIX}coverage run -m pytest $@ -PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov-config tests/.ignore_lifespan -W ignore::DeprecationWarning --cov=starlette --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@} -${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs -${PREFIX}autoflake --recursive starlette tests setup.py -${PREFIX}black starlette tests setup.py --check +if [ -z $GITHUB_ACTIONS ]; then + scripts/coverage +fi diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..f59a72029 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,39 @@ +[flake8] +ignore = W503, E203, B305 +max-line-length = 88 + +[mypy] +disallow_untyped_defs = True +ignore_missing_imports = True + +[mypy-tests.*] +disallow_untyped_defs = False +# https://github.com/encode/starlette/issues/1045 +# check_untyped_defs = True + +[tool:isort] +profile = black +combine_as_imports = True + +[tool:pytest] +addopts = + -rxXs + --strict-config + --strict-markers +xfail_strict=True +filterwarnings= + # Turn warnings that aren't filtered into exceptions + error + # Deprecated GraphQL (including https://github.com/graphql-python/graphene/issues/1055) + ignore: GraphQLApp is deprecated and will be removed in a future release\..*:DeprecationWarning + ignore: Using or importing the ABCs from 'collections' instead of from 'collections\.abc' is deprecated.*:DeprecationWarning + ignore: The 'context' alias has been deprecated. Please use 'context_value' instead\.:DeprecationWarning + ignore: The 'variables' alias has been deprecated. Please use 'variable_values' instead\.:DeprecationWarning + +[coverage:run] +source_pkgs = starlette, tests +# GraphQLApp incompatible with and untested on Python 3.10. It's deprecated, let's just ignore +# coverage for it until it's gone. +omit = + starlette/graphql.py + tests/test_graphql.py diff --git a/setup.py b/setup.py index de80e0ef9..31789fe09 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import os import re -from setuptools import setup +from setuptools import setup, find_packages def get_version(package): @@ -23,17 +23,6 @@ def get_long_description(): return f.read() -def get_packages(package): - """ - Return root package and all sub-packages. - """ - return [ - dirpath - for dirpath, dirnames, filenames in os.walk(package) - if os.path.exists(os.path.join(dirpath, "__init__.py")) - ] - - setup( name="starlette", python_requires=">=3.6", @@ -45,19 +34,22 @@ def get_packages(package): long_description_content_type="text/markdown", author="Tom Christie", author_email="tom@tomchristie.com", - packages=get_packages("starlette"), + packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, - data_files=[("", ["LICENSE.md"])], + include_package_data=True, + install_requires=[ + "anyio>=3.0.0,<4", + "typing_extensions; python_version < '3.8'", + "contextlib2 >= 21.6.0; python_version < '3.7'", + ], extras_require={ "full": [ - "aiofiles", - "graphene", + "graphene; python_version<'3.10'", "itsdangerous", "jinja2", "python-multipart", "pyyaml", "requests", - "ujson", ] }, classifiers=[ @@ -70,6 +62,9 @@ def get_packages(package): "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], zip_safe=False, ) diff --git a/starlette/__init__.py b/starlette/__init__.py index ea370a8e5..5a313cc7e 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.12.0" +__version__ = "0.16.0" diff --git a/starlette/applications.py b/starlette/applications.py index 8e7694bc2..ea52ee70e 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,7 +1,8 @@ import typing -from starlette.datastructures import URLPath +from starlette.datastructures import State, URLPath from starlette.exceptions import ExceptionMiddleware +from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import BaseRoute, Router @@ -9,16 +10,87 @@ class Starlette: + """ + Creates an application instance. + + **Parameters:** + + * **debug** - Boolean indicating if debug tracebacks should be returned on errors. + * **routes** - A list of routes to serve incoming HTTP and WebSocket requests. + * **middleware** - A list of middleware to run for every request. A starlette + application will always automatically include two middleware classes. + `ServerErrorMiddleware` is added as the very outermost middleware, to handle + any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal + with handled exception cases occurring in the routing or endpoints. + * **exception_handlers** - A dictionary mapping either integer status codes, + or exception class types onto callables which handle the exceptions. + Exception handler callables should be of the form + `handler(request, exc) -> response` and may be be either standard functions, or + async functions. + * **on_startup** - A list of callables to run on application startup. + Startup handler callables do not take any arguments, and may be be either + standard functions, or async functions. + * **on_shutdown** - A list of callables to run on application shutdown. + Shutdown handler callables do not take any arguments, and may be be either + standard functions, or async functions. + """ + def __init__( - self, debug: bool = False, routes: typing.List[BaseRoute] = None + self, + debug: bool = False, + routes: typing.Sequence[BaseRoute] = None, + middleware: typing.Sequence[Middleware] = None, + exception_handlers: typing.Dict[ + typing.Union[int, typing.Type[Exception]], typing.Callable + ] = None, + on_startup: typing.Sequence[typing.Callable] = None, + on_shutdown: typing.Sequence[typing.Callable] = None, + lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None, ) -> None: + # The lifespan context function is a newer style that replaces + # on_startup / on_shutdown handlers. Use one or the other, not both. + assert lifespan is None or ( + on_startup is None and on_shutdown is None + ), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both." + self._debug = debug - self.router = Router(routes) - self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) - self.error_middleware = ServerErrorMiddleware( - self.exception_middleware, debug=debug + self.state = State() + self.router = Router( + routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan + ) + self.exception_handlers = ( + {} if exception_handlers is None else dict(exception_handlers) + ) + self.user_middleware = [] if middleware is None else list(middleware) + self.middleware_stack = self.build_middleware_stack() + + def build_middleware_stack(self) -> ASGIApp: + debug = self.debug + error_handler = None + exception_handlers = {} + + for key, value in self.exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + middleware = ( + [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + self.user_middleware + + [ + Middleware( + ExceptionMiddleware, handlers=exception_handlers, debug=debug + ) + ] ) + app = self.router + for cls, options in reversed(middleware): + app = cls(app=app, **options) + return app + @property def routes(self) -> typing.List[BaseRoute]: return self.router.routes @@ -30,11 +102,19 @@ def debug(self) -> bool: @debug.setter def debug(self, value: bool) -> None: self._debug = value - self.exception_middleware.debug = value - self.error_middleware.debug = value + self.middleware_stack = self.build_middleware_stack() + + def url_path_for(self, name: str, **path_params: str) -> URLPath: + return self.router.url_path_for(name, **path_params) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + scope["app"] = self + await self.middleware_stack(scope, receive, send) + # The following usages are now discouraged in favour of configuration + #  during Starlette.__init__(...) def on_event(self, event_type: str) -> typing.Callable: - return self.router.lifespan.on_event(event_type) + return self.router.on_event(event_type) def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.router.mount(path, app=app, name=name) @@ -42,25 +122,20 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None: def host(self, host: str, app: ASGIApp, name: str = None) -> None: self.router.host(host, app=app, name=name) - def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None: - self.error_middleware.app = middleware_class( - self.error_middleware.app, **kwargs - ) + def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: + self.user_middleware.insert(0, Middleware(middleware_class, **options)) + self.middleware_stack = self.build_middleware_stack() def add_exception_handler( self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], handler: typing.Callable, ) -> None: - if exc_class_or_status_code in (500, Exception): - self.error_middleware.handler = handler - else: - self.exception_middleware.add_exception_handler( - exc_class_or_status_code, handler - ) + self.exception_handlers[exc_class_or_status_code] = handler + self.middleware_stack = self.build_middleware_stack() def add_event_handler(self, event_type: str, func: typing.Callable) -> None: - self.router.lifespan.add_event_handler(event_type, func) + self.router.add_event_handler(event_type, func) def add_route( self, @@ -124,10 +199,3 @@ def decorator(func: typing.Callable) -> typing.Callable: return func return decorator - - def url_path_for(self, name: str, **path_params: str) -> URLPath: - return self.router.url_path_for(name, **path_params) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - scope["app"] = self - await self.error_middleware(scope, receive, send) diff --git a/starlette/authentication.py b/starlette/authentication.py index 1fc6acaeb..44a9847fc 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -24,24 +24,23 @@ def requires( scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator(func: typing.Callable) -> typing.Callable: - type = None sig = inspect.signature(func) for idx, parameter in enumerate(sig.parameters.values()): if parameter.name == "request" or parameter.name == "websocket": - type = parameter.name + type_ = parameter.name break else: raise Exception( f'No "request" or "websocket" argument on function "{func}"' ) - if type == "websocket": + if type_ == "websocket": # Handle websocket functions. (Always async) @functools.wraps(func) async def websocket_wrapper( *args: typing.Any, **kwargs: typing.Any ) -> None: - websocket = kwargs.get("websocket", args[idx]) + websocket = kwargs.get("websocket", args[idx] if args else None) assert isinstance(websocket, WebSocket) if not has_required_scope(websocket, scopes_list): @@ -57,12 +56,14 @@ async def websocket_wrapper( async def async_wrapper( *args: typing.Any, **kwargs: typing.Any ) -> Response: - request = kwargs.get("request", args[idx]) + request = kwargs.get("request", args[idx] if args else None) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): if redirect is not None: - return RedirectResponse(url=request.url_for(redirect)) + return RedirectResponse( + url=request.url_for(redirect), status_code=303 + ) raise HTTPException(status_code=status_code) return await func(*args, **kwargs) @@ -72,12 +73,14 @@ async def async_wrapper( # Handle sync request/response functions. @functools.wraps(func) def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: - request = kwargs.get("request", args[idx]) + request = kwargs.get("request", args[idx] if args else None) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): if redirect is not None: - return RedirectResponse(url=request.url_for(redirect)) + return RedirectResponse( + url=request.url_for(redirect), status_code=303 + ) raise HTTPException(status_code=status_code) return func(*args, **kwargs) diff --git a/starlette/background.py b/starlette/background.py index b2a3cfe17..1160baeed 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -21,8 +21,8 @@ async def __call__(self) -> None: class BackgroundTasks(BackgroundTask): - def __init__(self, tasks: typing.Sequence[BackgroundTask] = []): - self.tasks = list(tasks) + def __init__(self, tasks: typing.Sequence[BackgroundTask] = None): + self.tasks = list(tasks) if tasks else [] def add_task( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 56db3db3c..e89d1e047 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,18 +1,32 @@ -import asyncio import functools import typing from typing import Any, AsyncGenerator, Iterator +import anyio + try: - import contextvars # Python 3.7+ only. + import contextvars # Python 3.7+ only or via contextvars backport. except ImportError: # pragma: no cover contextvars = None # type: ignore +T = typing.TypeVar("T") + + +async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: + async with anyio.create_task_group() as task_group: + + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + for func, kwargs in args: + task_group.start_soon(run, functools.partial(func, **kwargs)) + + async def run_in_threadpool( - func: typing.Callable, *args: typing.Any, **kwargs: typing.Any -) -> typing.Any: - loop = asyncio.get_event_loop() + func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any +) -> T: if contextvars is not None: # pragma: no cover # Ensure we run in the same context child = functools.partial(func, *args, **kwargs) @@ -20,9 +34,9 @@ async def run_in_threadpool( func = context.run args = (child,) elif kwargs: # pragma: no cover - # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) - return await loop.run_in_executor(None, func, *args) + return await anyio.to_thread.run_sync(func, *args) class _StopIteration(Exception): @@ -42,6 +56,6 @@ def _next(iterator: Iterator) -> Any: async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: while True: try: - yield await run_in_threadpool(_next, iterator) + yield await anyio.to_thread.run_sync(_next, iterator) except _StopIteration: break diff --git a/starlette/config.py b/starlette/config.py index 9a361c4e0..e9894e077 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -1,6 +1,7 @@ import os import typing from collections.abc import MutableMapping +from pathlib import Path class undefined: @@ -14,7 +15,7 @@ class EnvironError(Exception): class Environ(MutableMapping): def __init__(self, environ: typing.MutableMapping = os.environ): self._environ = environ - self._has_been_read = set() # type: typing.Set[typing.Any] + self._has_been_read: typing.Set[typing.Any] = set() def __getitem__(self, key: typing.Any) -> typing.Any: self._has_been_read.add(key) @@ -23,14 +24,16 @@ def __getitem__(self, key: typing.Any) -> typing.Any: def __setitem__(self, key: typing.Any, value: typing.Any) -> None: if key in self._has_been_read: raise EnvironError( - f"Attempting to set environ['{key}'], but the value has already be read." + f"Attempting to set environ['{key}'], but the value has already been " + "read." ) self._environ.__setitem__(key, value) def __delitem__(self, key: typing.Any) -> None: if key in self._has_been_read: raise EnvironError( - f"Attempting to delete environ['{key}'], but the value has already be read." + f"Attempting to delete environ['{key}'], but the value has already " + "been read." ) self._environ.__delitem__(key) @@ -46,20 +49,22 @@ def __len__(self) -> int: class Config: def __init__( - self, env_file: str = None, environ: typing.Mapping[str, str] = environ + self, + env_file: typing.Union[str, Path] = None, + environ: typing.Mapping[str, str] = environ, ) -> None: self.environ = environ - self.file_values = {} # type: typing.Dict[str, str] + self.file_values: typing.Dict[str, str] = {} if env_file is not None and os.path.isfile(env_file): self.file_values = self._read_file(env_file) def __call__( - self, key: str, cast: type = None, default: typing.Any = undefined + self, key: str, cast: typing.Callable = None, default: typing.Any = undefined ) -> typing.Any: return self.get(key, cast, default) def get( - self, key: str, cast: type = None, default: typing.Any = undefined + self, key: str, cast: typing.Callable = None, default: typing.Any = undefined ) -> typing.Any: if key in self.environ: value = self.environ[key] @@ -71,8 +76,8 @@ def get( return self._perform_cast(key, default, cast) raise KeyError(f"Config '{key}' is missing, and has no default.") - def _read_file(self, file_name: str) -> typing.Dict[str, str]: - file_values = {} # type: typing.Dict[str, str] + def _read_file(self, file_name: typing.Union[str, Path]) -> typing.Dict[str, str]: + file_values: typing.Dict[str, str] = {} with open(file_name) as input_file: for line in input_file.readlines(): line = line.strip() @@ -84,7 +89,7 @@ def _read_file(self, file_name: str) -> typing.Dict[str, str]: return file_values def _perform_cast( - self, key: str, value: typing.Any, cast: type = None + self, key: str, value: typing.Any, cast: typing.Callable = None ) -> typing.Any: if cast is None or value is None: return value diff --git a/starlette/convertors.py b/starlette/convertors.py index 854f3a7a9..7afe4c8d1 100644 --- a/starlette/convertors.py +++ b/starlette/convertors.py @@ -1,5 +1,6 @@ import math import typing +import uuid class Convertor: @@ -20,7 +21,7 @@ def convert(self, value: str) -> typing.Any: def to_string(self, value: typing.Any) -> str: value = str(value) - assert "/" not in value, "May not contain path seperators" + assert "/" not in value, "May not contain path separators" assert value, "Must not be empty" return value @@ -61,9 +62,20 @@ def to_string(self, value: typing.Any) -> str: return ("%0.20f" % value).rstrip("0").rstrip(".") +class UUIDConvertor(Convertor): + regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" + + def convert(self, value: str) -> typing.Any: + return uuid.UUID(value) + + def to_string(self, value: typing.Any) -> str: + return str(value) + + CONVERTOR_TYPES = { "str": StringConvertor(), "path": PathConvertor(), "int": IntegerConvertor(), "float": FloatConvertor(), + "uuid": UUIDConvertor(), } diff --git a/starlette/datastructures.py b/starlette/datastructures.py index e7259da1e..5149a6e2e 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -21,7 +21,7 @@ def __init__( scheme = scope.get("scheme", "http") server = scope.get("server", None) path = scope.get("root_path", "") + scope["path"] - query_string = scope["query_string"] + query_string = scope.get("query_string", b"") host_header = None for key, value in scope["headers"]: @@ -44,7 +44,7 @@ def __init__( if query_string: url += "?" + query_string.decode() elif components: - assert not url, 'Cannot set both "scope" and "**components".' + assert not url, 'Cannot set both "url" and "**components".' url = URL("").replace(**components).components.geturl() self._url = url @@ -121,6 +121,27 @@ def replace(self, **kwargs: typing.Any) -> "URL": components = self.components._replace(**kwargs) return self.__class__(components.geturl()) + def include_query_params(self, **kwargs: typing.Any) -> "URL": + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + params.update({str(key): str(value) for key, value in kwargs.items()}) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def replace_query_params(self, **kwargs: typing.Any) -> "URL": + query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) + return self.replace(query=query) + + def remove_query_params( + self, keys: typing.Union[str, typing.Sequence[str]] + ) -> "URL": + if isinstance(keys, str): + keys = [keys] + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + for key in keys: + params.pop(key, None) + query = urlencode(params.multi_items()) + return self.replace(query=query) + def __eq__(self, other: typing.Any) -> bool: return str(self) == str(other) @@ -140,7 +161,7 @@ class URLPath(str): Used by the routing to return `url_path_for` matches. """ - def __new__(cls, path: str, protocol: str = "", host: str = "") -> str: + def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath": assert protocol in ("http", "websocket", "") return str.__new__(cls, path) # type: ignore @@ -159,12 +180,9 @@ def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str: else: scheme = base_url.scheme - if self.host: - netloc = self.host - else: - netloc = base_url.netloc - - return str(URL(scheme=scheme, netloc=netloc, path=str(self))) + netloc = self.host or base_url.netloc + path = base_url.path.rstrip("/") + str(self) + return str(URL(scheme=scheme, netloc=netloc, path=path)) class Secret: @@ -209,7 +227,7 @@ def __repr__(self) -> str: return f"{class_name}({items!r})" def __str__(self) -> str: - return ", ".join([repr(item) for item in self]) + return ", ".join(repr(item) for item in self) class ImmutableMultiDict(typing.Mapping): @@ -224,11 +242,7 @@ def __init__( ) -> None: assert len(args) < 2, "Too many arguments." - if args: - value = args[0] - else: - value = [] - + value = args[0] if args else [] if kwargs: value = ( ImmutableMultiDict(value).multi_items() @@ -236,7 +250,7 @@ def __init__( ) if not value: - _items = [] # type: typing.List[typing.Tuple[typing.Any, typing.Any]] + _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = [] elif hasattr(value, "multi_items"): value = typing.cast(ImmutableMultiDict, value) _items = list(value.multi_items()) @@ -376,9 +390,11 @@ def __init__( value = args[0] if args else [] if isinstance(value, str): - super().__init__(parse_qsl(value), **kwargs) + super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs) elif isinstance(value, bytes): - super().__init__(parse_qsl(value.decode("latin-1")), **kwargs) + super().__init__( + parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs + ) else: super().__init__(*args, **kwargs) # type: ignore self._list = [(str(k), str(v)) for k, v in self._list] @@ -398,26 +414,44 @@ class UploadFile: An uploaded file included as part of the request data. """ + spool_max_size = 1024 * 1024 + def __init__( self, filename: str, file: typing.IO = None, content_type: str = "" ) -> None: self.filename = filename self.content_type = content_type if file is None: - file = tempfile.SpooledTemporaryFile() + file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size) self.file = file + @property + def _in_memory(self) -> bool: + rolled_to_disk = getattr(self.file, "_rolled", True) + return not rolled_to_disk + async def write(self, data: typing.Union[bytes, str]) -> None: - await run_in_threadpool(self.file.write, data) + if self._in_memory: + self.file.write(data) # type: ignore + else: + await run_in_threadpool(self.file.write, data) - async def read(self, size: int = None) -> typing.Union[bytes, str]: + async def read(self, size: int = -1) -> typing.Union[bytes, str]: + if self._in_memory: + return self.file.read(size) return await run_in_threadpool(self.file.read, size) async def seek(self, offset: int) -> None: - await run_in_threadpool(self.file.seek, offset) + if self._in_memory: + self.file.seek(offset) + else: + await run_in_threadpool(self.file.seek, offset) async def close(self) -> None: - await run_in_threadpool(self.file.close) + if self._in_memory: + self.file.close() + else: + await run_in_threadpool(self.file.close) class FormData(ImmutableMultiDict): @@ -434,7 +468,7 @@ def __init__( ], **kwargs: typing.Union[str, UploadFile], ) -> None: - super().__init__(*args, **kwargs) # type: ignore + super().__init__(*args, **kwargs) async def close(self) -> None: for key, value in self.multi_items(): @@ -453,7 +487,7 @@ def __init__( raw: typing.List[typing.Tuple[bytes, bytes]] = None, scope: Scope = None, ) -> None: - self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]] + self._list: typing.List[typing.Tuple[bytes, bytes]] = [] if headers is not None: assert raw is None, 'Cannot set both "headers" and "raw".' assert scope is None, 'Cannot set both "headers" and "scope".' @@ -605,3 +639,29 @@ def add_vary_header(self, vary: str) -> None: if existing is not None: vary = ", ".join([existing, vary]) self["vary"] = vary + + +class State: + """ + An object that can be used to store arbitrary state. + + Used for `request.state` and `app.state`. + """ + + def __init__(self, state: typing.Dict = None): + if state is None: + state = {} + super().__setattr__("_state", state) + + def __setattr__(self, key: typing.Any, value: typing.Any) -> None: + self._state[key] = value + + def __getattr__(self, key: typing.Any) -> typing.Any: + try: + return self._state[key] + except KeyError: + message = "'{}' object has no attribute '{}'" + raise AttributeError(message.format(self.__class__.__name__, key)) + + def __delattr__(self, key: typing.Any) -> None: + del self._state[key] diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 2ba68c04b..2504dd84d 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -43,7 +43,7 @@ async def method_not_allowed(self, request: Request) -> Response: class WebSocketEndpoint: - encoding = None # May be "text", "bytes", or "json". + encoding: typing.Optional[str] = None # May be "text", "bytes", or "json". def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] == "websocket" @@ -71,7 +71,7 @@ async def dispatch(self) -> None: break except Exception as exc: close_code = status.WS_1011_INTERNAL_ERROR - raise exc from None + raise exc finally: await self.on_disconnect(websocket, close_code) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 7e3f272a8..557534d87 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -7,7 +7,7 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -from starlette.websockets import WebSocket, WebSocketClose +from starlette.websockets import WebSocket class HTTPException(Exception): @@ -17,6 +17,10 @@ def __init__(self, status_code: int, detail: str = None) -> None: self.status_code = status_code self.detail = detail + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" + class WebSocketException(Exception): def __init__(self, code: int = status.WS_1008_POLICY_VIOLATION) -> None: @@ -29,21 +33,28 @@ def __init__(self, code: int = status.WS_1008_POLICY_VIOLATION) -> None: > other more suitable status code (e.g., 1003 or 1009) or if there > is a need to hide specific details about the policy. - Set `code` to any value allowed by - [the WebSocket specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). + Set `code` to any value allowed by the + [WebSocket specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). """ self.code = code class ExceptionMiddleware: - def __init__(self, app: ASGIApp, debug: bool = False) -> None: + def __init__( + self, app: ASGIApp, handlers: dict = None, debug: bool = False + ) -> None: self.app = app self.debug = debug # TODO: We ought to handle 404 cases if debug is set. - self._status_handlers = {} # type: typing.Dict[int, typing.Callable] - self._exception_handlers = { + self._status_handlers: typing.Dict[int, typing.Callable] = {} + self._exception_handlers: typing.Dict[ + typing.Type[Exception], typing.Callable + ] = { HTTPException: self.http_exception, WebSocketException: self.websocket_exception, - } # type: typing.Dict[typing.Type[Exception], typing.Callable] + } + if handlers is not None: + for key, value in handlers.items(): + self.add_exception_handler(key, value) def add_exception_handler( self, @@ -90,7 +101,7 @@ async def sender(message: Message) -> None: handler = self._lookup_exception_handler(exc) if handler is None: - raise exc from None + raise exc if response_started: msg = "Caught handled exception, but response already started." diff --git a/starlette/formparsers.py b/starlette/formparsers.py index cbbe8949b..1614a9d69 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -5,11 +5,11 @@ from starlette.datastructures import FormData, Headers, UploadFile try: - from multipart.multipart import parse_options_header import multipart + from multipart.multipart import parse_options_header except ImportError: # pragma: nocover - parse_options_header = None # type: ignore - multipart = None # type: ignore + parse_options_header = None + multipart = None class FormMessage(Enum): @@ -31,6 +31,13 @@ class MultiPartMessage(Enum): END = 8 +def _user_safe_decode(src: bytes, codec: str) -> str: + try: + return src.decode(codec) + except (UnicodeDecodeError, LookupError): + return src.decode("latin-1") + + class FormParser: def __init__( self, headers: Headers, stream: typing.AsyncGenerator[bytes, None] @@ -40,7 +47,7 @@ def __init__( ), "The `python-multipart` library must be installed to use form parsing." self.headers = headers self.stream = stream - self.messages = [] # type: typing.List[typing.Tuple[FormMessage, bytes]] + self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = [] def on_field_start(self) -> None: message = (FormMessage.FIELD_START, b"") @@ -77,9 +84,7 @@ async def parse(self) -> FormData: field_name = b"" field_value = b"" - items = ( - [] - ) # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] + items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] # Feed the parser with data from the request. async for chunk in self.stream: @@ -101,8 +106,6 @@ async def parse(self) -> FormData: name = unquote_plus(field_name.decode("latin-1")) value = unquote_plus(field_value.decode("latin-1")) items.append((name, value)) - elif message_type == FormMessage.END: - pass return FormData(items) @@ -116,7 +119,7 @@ def __init__( ), "The `python-multipart` library must be installed to use form parsing." self.headers = headers self.stream = stream - self.messages = [] # type: typing.List[typing.Tuple[MultiPartMessage, bytes]] + self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = [] def on_part_begin(self) -> None: message = (MultiPartMessage.PART_BEGIN, b"") @@ -153,6 +156,9 @@ def on_end(self) -> None: async def parse(self) -> FormData: # Parse the Content-Type header to get the multipart boundary. content_type, params = parse_options_header(self.headers["Content-Type"]) + charset = params.get(b"charset", "utf-8") + if type(charset) == bytes: + charset = charset.decode("latin-1") boundary = params.get(b"boundary") # Callbacks dictionary. @@ -171,14 +177,13 @@ async def parse(self) -> FormData: parser = multipart.MultipartParser(boundary, callbacks) header_field = b"" header_value = b"" - raw_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + content_disposition = None + content_type = b"" field_name = "" data = b"" - file = None # type: typing.Optional[UploadFile] + file: typing.Optional[UploadFile] = None - items = ( - [] - ) # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] + items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] # Feed the parser with data from the request. async for chunk in self.stream: @@ -187,25 +192,30 @@ async def parse(self) -> FormData: self.messages.clear() for message_type, message_bytes in messages: if message_type == MultiPartMessage.PART_BEGIN: - raw_headers = [] + content_disposition = None + content_type = b"" data = b"" elif message_type == MultiPartMessage.HEADER_FIELD: header_field += message_bytes elif message_type == MultiPartMessage.HEADER_VALUE: header_value += message_bytes elif message_type == MultiPartMessage.HEADER_END: - raw_headers.append((header_field.lower(), header_value)) + field = header_field.lower() + if field == b"content-disposition": + content_disposition = header_value + elif field == b"content-type": + content_type = header_value header_field = b"" header_value = b"" elif message_type == MultiPartMessage.HEADERS_FINISHED: - headers = Headers(raw=raw_headers) - content_disposition = headers.get("Content-Disposition") - content_type = headers.get("Content-Type", "") disposition, options = parse_options_header(content_disposition) - field_name = options[b"name"].decode("latin-1") + field_name = _user_safe_decode(options[b"name"], charset) if b"filename" in options: - filename = options[b"filename"].decode("latin-1") - file = UploadFile(filename=filename, content_type=content_type) + filename = _user_safe_decode(options[b"filename"], charset) + file = UploadFile( + filename=filename, + content_type=content_type.decode("latin-1"), + ) else: file = None elif message_type == MultiPartMessage.PART_DATA: @@ -215,12 +225,10 @@ async def parse(self) -> FormData: await file.write(message_bytes) elif message_type == MultiPartMessage.PART_END: if file is None: - items.append((field_name, data.decode("latin-1"))) + items.append((field_name, _user_safe_decode(data, charset))) else: await file.seek(0) items.append((field_name, file)) - elif message_type == MultiPartMessage.END: - pass parser.finalize() return FormData(items) diff --git a/starlette/graphql.py b/starlette/graphql.py index 2063c0789..6e5d6ec6a 100644 --- a/starlette/graphql.py +++ b/starlette/graphql.py @@ -1,5 +1,6 @@ import json import typing +import warnings from starlette import status from starlette.background import BackgroundTasks @@ -8,13 +9,19 @@ from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response from starlette.types import Receive, Scope, Send +warnings.warn( + "GraphQLApp is deprecated and will be removed in a future release. " + "Consider using a third-party GraphQL implementation. " + "See https://github.com/encode/starlette/issues/619.", + DeprecationWarning, +) + try: import graphene + from graphql.error import GraphQLError, format_error as format_graphql_error from graphql.execution.executors.asyncio import AsyncioExecutor - from graphql.error import format_error as format_graphql_error - from graphql.error import GraphQLError except ImportError: # pragma: nocover - graphene = None # type: ignore + graphene = None AsyncioExecutor = None # type: ignore format_graphql_error = None # type: ignore GraphQLError = None # type: ignore @@ -24,29 +31,18 @@ class GraphQLApp: def __init__( self, schema: "graphene.Schema", - executor: typing.Any = None, executor_class: type = None, graphiql: bool = True, ) -> None: self.schema = schema self.graphiql = graphiql - if executor is None: - # New style in 0.10.0. Use 'executor_class'. - # See issue https://github.com/encode/starlette/issues/242 - self.executor = executor - self.executor_class = executor_class - self.is_async = executor_class is not None and issubclass( - executor_class, AsyncioExecutor - ) - else: - # Old style. Use 'executor'. - # We should remove this in the next median/major version bump. - self.executor = executor - self.executor_class = None - self.is_async = isinstance(executor, AsyncioExecutor) + self.executor_class = executor_class + self.is_async = executor_class is not None and issubclass( + executor_class, AsyncioExecutor + ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if self.executor is None and self.executor_class is not None: + if self.executor_class is not None: self.executor = self.executor_class() request = Request(scope, receive=receive) @@ -62,7 +58,7 @@ async def handle_graphql(self, request: Request) -> Response: ) return await self.handle_graphiql(request) - data = request.query_params # type: typing.Mapping[str, typing.Any] + data: typing.Mapping[str, typing.Any] = request.query_params elif request.method == "POST": content_type = request.headers.get("Content-Type", "") @@ -107,7 +103,9 @@ async def handle_graphql(self, request: Request) -> Response: if result.errors else None ) - response_data = {"data": result.data, "errors": error_data} + response_data = {"data": result.data} + if error_data: + response_data["errors"] = error_data status_code = ( status.HTTP_400_BAD_REQUEST if result.errors else status.HTTP_200_OK ) @@ -274,4 +272,4 @@ async def handle_graphiql(self, request: Request) -> Response: -""" +""" # noqa: E501 diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index e69de29bb..5ac5b96c8 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -0,0 +1,17 @@ +import typing + + +class Middleware: + def __init__(self, cls: type, **options: typing.Any) -> None: + self.cls = cls + self.options = options + + def __iter__(self) -> typing.Iterator: + as_tuple = (self.cls, self.options) + return iter(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + option_strings = [f"{key}={value!r}" for key, value in self.options.items()] + args_repr = ", ".join([self.cls.__name__] + option_strings) + return f"{class_name}({args_repr})" diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index c9e4d4f68..6e2d2dade 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -22,9 +22,9 @@ def __init__( ) -> None: self.app = app self.backend = backend - self.on_error = ( - on_error if on_error is not None else self.default_on_error - ) # type: typing.Callable[[HTTPConnection, AuthenticationError], Response] + self.on_error: typing.Callable[ + [HTTPConnection, AuthenticationError], Response + ] = (on_error if on_error is not None else self.default_on_error) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ["http", "websocket"]: diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index ea5afb210..77ba66925 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,6 +1,7 @@ -import asyncio import typing +import anyio + from starlette.requests import Request from starlette.responses import Response, StreamingResponse from starlette.types import ASGIApp, Receive, Scope, Send @@ -21,45 +22,39 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request = Request(scope, receive=receive) - response = await self.dispatch_func(request, self.call_next) - await response(scope, receive, send) + async def call_next(request: Request) -> Response: + send_stream, recv_stream = anyio.create_memory_object_stream() - async def call_next(self, request: Request) -> Response: - loop = asyncio.get_event_loop() - queue = asyncio.Queue() # type: asyncio.Queue + async def coro() -> None: + async with send_stream: + await self.app(scope, request.receive, send_stream.send) - scope = dict(request) - receive = request.receive - send = queue.put + task_group.start_soon(coro) - async def coro() -> None: try: - await self.app(scope, receive, send) - finally: - await queue.put(None) - - task = loop.create_task(coro()) - message = await queue.get() - if message is None: - task.result() - raise RuntimeError("No response returned.") - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: - message = await queue.get() - if message is None: - break - assert message["type"] == "http.response.body" - yield message["body"] - task.result() - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response + message = await recv_stream.receive() + except anyio.EndOfStream: + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async with recv_stream: + async for message in recv_stream: + assert message["type"] == "http.response.body" + yield message.get("body", b"") + + response = StreamingResponse( + status_code=message["status"], content=body_stream() + ) + response.raw_headers = message["headers"] + return response + + async with anyio.create_task_group() as task_group: + request = Request(scope, receive=receive) + response = await self.dispatch_func(request, call_next) + await response(scope, receive, send) + task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index ad2eeff48..c850579c8 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -6,7 +6,8 @@ from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -ALL_METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT") +ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") +SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} class CORSMiddleware: @@ -29,8 +30,12 @@ def __init__( if allow_origin_regex is not None: compiled_allow_origin_regex = re.compile(allow_origin_regex) + allow_all_origins = "*" in allow_origins + allow_all_headers = "*" in allow_headers + preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + simple_headers = {} - if "*" in allow_origins: + if allow_all_origins: simple_headers["Access-Control-Allow-Origin"] = "*" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" @@ -38,17 +43,19 @@ def __init__( simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) preflight_headers = {} - if "*" in allow_origins: - preflight_headers["Access-Control-Allow-Origin"] = "*" - else: + if preflight_explicit_allow_origin: + # The origin value will be set in preflight_response() if it is allowed. preflight_headers["Vary"] = "Origin" + else: + preflight_headers["Access-Control-Allow-Origin"] = "*" preflight_headers.update( { "Access-Control-Allow-Methods": ", ".join(allow_methods), "Access-Control-Max-Age": str(max_age), } ) - if allow_headers and "*" not in allow_headers: + allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) + if allow_headers and not allow_all_headers: preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) if allow_credentials: preflight_headers["Access-Control-Allow-Credentials"] = "true" @@ -57,8 +64,9 @@ def __init__( self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] - self.allow_all_origins = "*" in allow_origins - self.allow_all_headers = "*" in allow_headers + self.allow_all_origins = allow_all_origins + self.allow_all_headers = allow_all_headers + self.preflight_explicit_allow_origin = preflight_explicit_allow_origin self.allow_origin_regex = compiled_allow_origin_regex self.simple_headers = simple_headers self.preflight_headers = preflight_headers @@ -87,7 +95,7 @@ def is_allowed_origin(self, origin: str) -> bool: if self.allow_all_origins: return True - if self.allow_origin_regex is not None and self.allow_origin_regex.match( + if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch( origin ): return True @@ -103,11 +111,9 @@ def preflight_response(self, request_headers: Headers) -> Response: failures = [] if self.is_allowed_origin(origin=requested_origin): - if not self.allow_all_origins: - # If self.allow_all_origins is True, then the "Access-Control-Allow-Origin" - # header is already set to "*". - # If we only allow specific origins, then we have to mirror back - # the Origin header in the response. + if self.preflight_explicit_allow_origin: + # The "else" case is already accounted for in self.preflight_headers + # and the value would be "*". headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") @@ -123,6 +129,7 @@ def preflight_response(self, request_headers: Headers) -> Response: for header in [h.lower() for h in requested_headers.split(",")]: if header.strip() not in self.allow_headers: failures.append("headers") + break # We don't strictly need to use 400 responses here, since its up to # the browser to enforce the CORS policy, but its more informative @@ -155,11 +162,16 @@ async def send( # If request includes any cookie headers, then we must respond # with the specific origin instead of '*'. if self.allow_all_origins and has_cookie: - headers["Access-Control-Allow-Origin"] = origin + self.allow_explicit_origin(headers, origin) # If we only allow specific origins, then we have to mirror back # the Origin header in the response. elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): - headers["Access-Control-Allow-Origin"] = origin - headers.add_vary_header("Origin") + self.allow_explicit_origin(headers, origin) + await send(message) + + @staticmethod + def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: + headers["Access-Control-Allow-Origin"] = origin + headers.add_vary_header("Origin") diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 643ebc104..0eaae03ad 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -1,15 +1,20 @@ import asyncio +import html +import inspect import traceback import typing from starlette import status from starlette.concurrency import run_in_threadpool -from starlette.requests import Request, empty_receive +from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -from starlette.websockets import WebSocket, WebSocketState +from starlette.websockets import WebSocket STYLES = """ +p { + color: #211c1c; +} .traceback-container { border: 1px solid #038BB8; } @@ -20,20 +25,61 @@ font-size: 20px; margin-top: 0px; } -.traceback-content { - padding: 5px 0px 20px 20px; -} .frame-line { + padding-left: 10px; + font-family: monospace; +} +.frame-filename { + font-family: monospace; +} +.center-line { + background-color: #038BB8; + color: #f9f6e1; + padding: 5px 0px 5px 5px; +} +.lineno { + margin-right: 5px; +} +.frame-title { font-weight: unset; - padding: 10px 10px 10px 20px; + padding: 10px 10px 10px 10px; background-color: #E4F4FD; - margin-left: 10px; margin-right: 10px; - font: #394D54; color: #191f21; font-size: 17px; border: 1px solid #c7dce8; } +.collapse-btn { + float: right; + padding: 0px 5px 1px 5px; + border: solid 1px #96aebb; + cursor: pointer; +} +.collapsed { + display: none; +} +.source-code { + font-family: courier; + font-size: small; + padding-bottom: 10px; +} +""" + +JS = """ + """ TEMPLATE = """ @@ -47,21 +93,34 @@

500 Server Error

{error}

-
-

Traceback

-
{exc_html}
+
+

Traceback

+
{exc_html}
+ {js} """ FRAME_TEMPLATE = """
- File `{frame_filename}`, +

File {frame_filename}, line {frame_lineno}, in {frame_name} -

{frame_line}

+ {collapse_button} +

+
{code_context}
+""" # noqa: E501 + +LINE = """ +

+{lineno}. {line}

+""" + +CENTER_LINE = """ +

+{lineno}. {line}

""" @@ -128,30 +187,65 @@ async def _send(message: Message) -> None: # We always continue to raise the exception. # This allows servers to log the error, or allows test clients # to optionally raise the error within the test case. - raise exc from None + raise exc + + def format_line( + self, index: int, line: str, frame_lineno: int, frame_index: int + ) -> str: + values = { + # HTML escape - line could contain < or > + "line": html.escape(line).replace(" ", " "), + "lineno": (frame_lineno - frame_index) + index, + } + + if index != frame_index: + return LINE.format(**values) + return CENTER_LINE.format(**values) + + def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str: + code_context = "".join( + self.format_line(index, line, frame.lineno, frame.index) # type: ignore + for index, line in enumerate(frame.code_context or []) + ) - def generate_frame_html(self, frame: traceback.FrameSummary) -> str: values = { - "frame_filename": frame.filename, + # HTML escape - filename could contain < or >, especially if it's a virtual + # file e.g. in the REPL + "frame_filename": html.escape(frame.filename), "frame_lineno": frame.lineno, - "frame_name": frame.name, - "frame_line": frame.line, + # HTML escape - if you try very hard it's possible to name a function with < + # or > + "frame_name": html.escape(frame.function), + "code_context": code_context, + "collapsed": "collapsed" if is_collapsed else "", + "collapse_button": "+" if is_collapsed else "‒", } return FRAME_TEMPLATE.format(**values) - def generate_html(self, exc: Exception) -> str: + def generate_html(self, exc: Exception, limit: int = 7) -> str: traceback_obj = traceback.TracebackException.from_exception( exc, capture_locals=True ) - exc_html = "".join( - self.generate_frame_html(frame) for frame in traceback_obj.stack + + exc_html = "" + is_collapsed = False + exc_traceback = exc.__traceback__ + if exc_traceback is not None: + frames = inspect.getinnerframes(exc_traceback, limit) + for frame in reversed(frames): + exc_html += self.generate_frame_html(frame, is_collapsed) + is_collapsed = True + + # escape error class and text + error = ( + f"{html.escape(traceback_obj.exc_type.__name__)}: " + f"{html.escape(str(traceback_obj))}" ) - error = f"{traceback_obj.exc_type.__name__}: {traceback_obj}" - return TEMPLATE.format(styles=STYLES, error=error, exc_html=exc_html) + return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html) def generate_plain_text(self, exc: Exception) -> str: - return "".join(traceback.format_tb(exc.__traceback__)) + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) def debug_response(self, request: Request, exc: Exception) -> Response: accept = request.headers.get("accept", "") diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index bb634e36d..37c6936fa 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -6,29 +6,36 @@ class GZipMiddleware: - def __init__(self, app: ASGIApp, minimum_size: int = 500) -> None: + def __init__( + self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9 + ) -> None: self.app = app self.minimum_size = minimum_size + self.compresslevel = compresslevel async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": headers = Headers(scope=scope) if "gzip" in headers.get("Accept-Encoding", ""): - responder = GZipResponder(self.app, self.minimum_size) + responder = GZipResponder( + self.app, self.minimum_size, compresslevel=self.compresslevel + ) await responder(scope, receive, send) return await self.app(scope, receive, send) class GZipResponder: - def __init__(self, app: ASGIApp, minimum_size: int) -> None: + def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None: self.app = app self.minimum_size = minimum_size - self.send = unattached_send # type: Send - self.initial_message = {} # type: Message + self.send: Send = unattached_send + self.initial_message: Message = {} self.started = False self.gzip_buffer = io.BytesIO() - self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer) + self.gzip_file = gzip.GzipFile( + mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel + ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.send = send @@ -38,7 +45,7 @@ async def send_with_gzip(self, message: Message) -> None: message_type = message["type"] if message_type == "http.response.start": # Don't send the initial message until we've determined how to - # modify the ougoging headers correctly. + # modify the outgoing headers correctly. self.initial_message = message elif message_type == "http.response.body" and not self.started: self.started = True diff --git a/starlette/middleware/httpsredirect.py b/starlette/middleware/httpsredirect.py index 13f3a70a6..a8359067f 100644 --- a/starlette/middleware/httpsredirect.py +++ b/starlette/middleware/httpsredirect.py @@ -13,7 +13,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme] netloc = url.hostname if url.port in (80, 443) else url.netloc url = url.replace(scheme=redirect_scheme, netloc=netloc) - response = RedirectResponse(url, status_code=301) + response = RedirectResponse(url, status_code=307) await response(scope, receive, send) else: await self.app(scope, receive, send) diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 8e47454fc..a13ec5c0e 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -6,7 +6,7 @@ from itsdangerous.exc import BadTimeSignature, SignatureExpired from starlette.datastructures import MutableHeaders, Secret -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -33,11 +33,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request = Request(scope) + connection = HTTPConnection(scope) initial_session_was_empty = True - if self.session_cookie in request.cookies: - data = request.cookies[self.session_cookie].encode("utf-8") + if self.session_cookie in connection.cookies: + data = connection.cookies[self.session_cookie].encode("utf-8") try: data = self.signer.unsign(data, max_age=self.max_age) scope["session"] = json.loads(b64decode(data)) @@ -49,14 +49,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def send_wrapper(message: Message) -> None: if message["type"] == "http.response.start": + path = scope.get("root_path", "") or "/" if scope["session"]: # We have session data to persist. data = b64encode(json.dumps(scope["session"]).encode("utf-8")) data = self.signer.sign(data) headers = MutableHeaders(scope=message) - header_value = "%s=%s; path=/; Max-Age=%d; %s" % ( + header_value = "%s=%s; path=%s; Max-Age=%d; %s" % ( self.session_cookie, data.decode("utf-8"), + path, self.max_age, self.security_flags, ) @@ -64,9 +66,9 @@ async def send_wrapper(message: Message) -> None: elif not initial_session_was_empty: # The session has been cleared. headers = MutableHeaders(scope=message) - header_value = "%s=%s; %s" % ( + header_value = "{}={}; {}".format( self.session_cookie, - "null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT;", + f"null; path={path}; expires=Thu, 01 Jan 1970 00:00:00 GMT;", self.security_flags, ) headers.append("Set-Cookie", header_value) diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 398365248..6bc4d2b5e 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -50,10 +50,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if is_valid_host: await self.app(scope, receive, send) else: + response: Response if found_www_redirect and self.www_redirect: url = URL(scope=scope) redirect_url = url.replace(netloc="www." + url.netloc) - response = RedirectResponse(url=str(redirect_url)) # type: Response + response = RedirectResponse(url=str(redirect_url)) else: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send) diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index a9aef0942..7e69e1a6b 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -1,10 +1,11 @@ -import asyncio import io +import math import sys import typing -from starlette.concurrency import run_in_threadpool -from starlette.types import Message, Receive, Scope, Send +import anyio + +from starlette.types import Receive, Scope, Send def build_environ(scope: Scope, body: bytes) -> dict: @@ -13,8 +14,8 @@ def build_environ(scope: Scope, body: bytes) -> dict: """ environ = { "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": scope.get("root_path", ""), - "PATH_INFO": scope["path"], + "SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"), + "PATH_INFO": scope["path"].encode("utf8").decode("latin1"), "QUERY_STRING": scope["query_string"].decode("ascii"), "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", "wsgi.version": (1, 0), @@ -44,7 +45,8 @@ def build_environ(scope: Scope, body: bytes) -> dict: corrected_name = "CONTENT_TYPE" else: corrected_name = f"HTTP_{name}".upper().replace("-", "_") - # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case + # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in + # case value = value.decode("latin1") if corrected_name in environ: value = environ[corrected_name] + "," + value @@ -53,7 +55,7 @@ def build_environ(scope: Scope, body: bytes) -> dict: class WSGIMiddleware: - def __init__(self, app: typing.Callable, workers: int = 10) -> None: + def __init__(self, app: typing.Callable) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -68,11 +70,11 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.send_event = asyncio.Event() - self.send_queue = [] # type: typing.List[typing.Optional[Message]] - self.loop = asyncio.get_event_loop() + self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + math.inf + ) self.response_started = False - self.exc_info = None # type: typing.Any + self.exc_info: typing.Any = None async def __call__(self, receive: Receive, send: Send) -> None: body = b"" @@ -82,30 +84,18 @@ async def __call__(self, receive: Receive, send: Send) -> None: body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - try: - sender = self.loop.create_task(self.sender(send)) - await run_in_threadpool(self.wsgi, environ, self.start_response) - self.send_queue.append(None) - self.send_event.set() - await asyncio.wait_for(sender, None) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback( - self.exc_info[1], self.exc_info[2] - ) - finally: - if not sender.done(): - sender.cancel() # pragma: no cover + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) async def sender(self, send: Send) -> None: - while True: - if self.send_queue: - message = self.send_queue.pop(0) - if message is None: - return + async with self.stream_receive: + async for message in self.stream_receive: await send(message) - else: - await self.send_event.wait() - self.send_event.clear() def start_response( self, @@ -119,24 +109,25 @@ def start_response( status_code_string, _ = status.split(" ", 1) status_code = int(status_code_string) headers = [ - (name.encode("ascii"), value.encode("ascii")) + (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - self.send_queue.append( + anyio.from_thread.run( + self.stream_send.send, { "type": "http.response.start", "status": status_code, "headers": headers, - } + }, ) - self.loop.call_soon_threadsafe(self.send_event.set) def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - self.send_queue.append( - {"type": "http.response.body", "body": chunk, "more_body": True} + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, ) - self.loop.call_soon_threadsafe(self.send_event.set) - self.send_queue.append({"type": "http.response.body", "body": b""}) - self.loop.call_soon_threadsafe(self.send_event.set) + anyio.from_thread.run( + self.stream_send.send, {"type": "http.response.body", "body": b""} + ) diff --git a/starlette/requests.py b/starlette/requests.py index 44433a064..f88021645 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,24 +1,57 @@ -import asyncio -import http.cookies import json import typing from collections.abc import Mapping +from http import cookies as http_cookies -from starlette.datastructures import URL, Address, FormData, Headers, QueryParams +import anyio + +from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.formparsers import FormParser, MultiPartParser -from starlette.types import Message, Receive, Scope +from starlette.types import Message, Receive, Scope, Send try: from multipart.multipart import parse_options_header except ImportError: # pragma: nocover - parse_options_header = None # type: ignore + parse_options_header = None -class ClientDisconnect(Exception): - pass +SERVER_PUSH_HEADERS_TO_COPY = { + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "user-agent", +} + + +def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: + """ + This function parses a ``Cookie`` HTTP header into a dict of key/value pairs. + + It attempts to mimic browser cookie parsing behavior: browsers and web servers + frequently disregard the spec (RFC 6265) when setting and reading cookies, + so we attempt to suit the common scenarios here. + + This function has been adapted from Django 3.1.0. + Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based + on an outdated spec and will fail on lots of input we want to support + """ + cookie_dict: typing.Dict[str, str] = {} + for chunk in cookie_string.split(";"): + if "=" in chunk: + key, val = chunk.split("=", 1) + else: + # Assume an empty name per + # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 + key, val = "", chunk + key, val = key.strip(), val.strip() + if key or val: + # unquote using Python's algorithm. + cookie_dict[key] = http_cookies._unquote(val) # type: ignore + return cookie_dict -class State: +class ClientDisconnect(Exception): pass @@ -30,109 +63,136 @@ class HTTPConnection(Mapping): def __init__(self, scope: Scope, receive: Receive = None) -> None: assert scope["type"] in ("http", "websocket") - self._scope = scope + self.scope = scope def __getitem__(self, key: str) -> str: - return self._scope[key] + return self.scope[key] def __iter__(self) -> typing.Iterator[str]: - return iter(self._scope) + return iter(self.scope) def __len__(self) -> int: - return len(self._scope) + return len(self.scope) + + # Don't use the `abc.Mapping.__eq__` implementation. + # Connection instances should never be considered equal + # unless `self is other`. + __eq__ = object.__eq__ + __hash__ = object.__hash__ @property def app(self) -> typing.Any: - return self._scope["app"] + return self.scope["app"] @property def url(self) -> URL: if not hasattr(self, "_url"): - self._url = URL(scope=self._scope) + self._url = URL(scope=self.scope) return self._url + @property + def base_url(self) -> URL: + if not hasattr(self, "_base_url"): + base_url_scope = dict(self.scope) + base_url_scope["path"] = "/" + base_url_scope["query_string"] = b"" + base_url_scope["root_path"] = base_url_scope.get( + "app_root_path", base_url_scope.get("root_path", "") + ) + self._base_url = URL(scope=base_url_scope) + return self._base_url + @property def headers(self) -> Headers: if not hasattr(self, "_headers"): - self._headers = Headers(scope=self._scope) + self._headers = Headers(scope=self.scope) return self._headers @property def query_params(self) -> QueryParams: if not hasattr(self, "_query_params"): - self._query_params = QueryParams(self._scope["query_string"]) + self._query_params = QueryParams(self.scope["query_string"]) return self._query_params @property def path_params(self) -> dict: - return self._scope.get("path_params", {}) + return self.scope.get("path_params", {}) @property def cookies(self) -> typing.Dict[str, str]: if not hasattr(self, "_cookies"): - cookies = {} + cookies: typing.Dict[str, str] = {} cookie_header = self.headers.get("cookie") + if cookie_header: - cookie = http.cookies.SimpleCookie() - cookie.load(cookie_header) - for key, morsel in cookie.items(): - cookies[key] = morsel.value + cookies = cookie_parser(cookie_header) self._cookies = cookies return self._cookies @property def client(self) -> Address: - host, port = self._scope.get("client") or (None, None) + host, port = self.scope.get("client") or (None, None) return Address(host=host, port=port) @property def session(self) -> dict: assert ( - "session" in self._scope + "session" in self.scope ), "SessionMiddleware must be installed to access request.session" - return self._scope["session"] + return self.scope["session"] @property def auth(self) -> typing.Any: assert ( - "auth" in self._scope + "auth" in self.scope ), "AuthenticationMiddleware must be installed to access request.auth" - return self._scope["auth"] + return self.scope["auth"] @property def user(self) -> typing.Any: assert ( - "user" in self._scope + "user" in self.scope ), "AuthenticationMiddleware must be installed to access request.user" - return self._scope["user"] + return self.scope["user"] @property def state(self) -> State: - if "state" not in self._scope: - self._scope["state"] = State() - return self._scope["state"] + if not hasattr(self, "_state"): + # Ensure 'state' has an empty dict if it's not already populated. + self.scope.setdefault("state", {}) + # Create a state instance with a reference to the dict in which it should + # store info + self._state = State(self.scope["state"]) + return self._state def url_for(self, name: str, **path_params: typing.Any) -> str: - router = self._scope["router"] + router = self.scope["router"] url_path = router.url_path_for(name, **path_params) - return url_path.make_absolute_url(base_url=self.url) + return url_path.make_absolute_url(base_url=self.base_url) async def empty_receive() -> Message: raise RuntimeError("Receive channel has not been made available") +async def empty_send(message: Message) -> None: + raise RuntimeError("Send channel has not been made available") + + class Request(HTTPConnection): - def __init__(self, scope: Scope, receive: Receive = empty_receive): + def __init__( + self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send + ): super().__init__(scope) assert scope["type"] == "http" self._receive = receive + self._send = send self._stream_consumed = False self._is_disconnected = False @property def method(self) -> str: - return self._scope["method"] + return self.scope["method"] @property def receive(self) -> Receive: @@ -163,10 +223,10 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: async def body(self) -> bytes: if not hasattr(self, "_body"): - body = b"" + chunks = [] async for chunk in self.stream(): - body += chunk - self._body = body + chunks.append(chunk) + self._body = b"".join(chunks) return self._body async def json(self) -> typing.Any: @@ -198,12 +258,26 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: - try: - message = await asyncio.wait_for(self._receive(), timeout=0.0000001) - except asyncio.TimeoutError: - message = {} + message: Message = {} + + # If message isn't immediately available, move on + with anyio.CancelScope() as cs: + cs.cancel() + message = await self._receive() if message.get("type") == "http.disconnect": self._is_disconnected = True return self._is_disconnected + + async def send_push_promise(self, path: str) -> None: + if "http.response.push" in self.scope.get("extensions", {}): + raw_headers = [] + for name in SERVER_PUSH_HEADERS_TO_COPY: + for value in self.headers.getlist(name): + raw_headers.append( + (name.encode("latin-1"), value.encode("latin-1")) + ) + await self._send( + {"type": "http.response.push", "path": path, "headers": raw_headers} + ) diff --git a/starlette/responses.py b/starlette/responses.py index 48afc47ad..d03df2329 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -1,30 +1,33 @@ import hashlib import http.cookies -import inspect import json import os import stat +import sys import typing from email.utils import formatdate -from mimetypes import guess_type -from urllib.parse import quote_plus +from functools import partial +from mimetypes import guess_type as mimetypes_guess_type +from urllib.parse import quote + +import anyio from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send -try: - import aiofiles - from aiofiles.os import stat as aio_stat -except ImportError: # pragma: nocover - aiofiles = None # type: ignore - aio_stat = None # type: ignore +# Workaround for adding samesite support to pre 3.8 python +http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore + -try: - import ujson -except ImportError: # pragma: nocover - ujson = None # type: ignore +# Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on typing.Tuple[typing.Optional[str], typing.Optional[str]]: + if sys.version_info < (3, 8): # pragma: no cover + url = os.fspath(url) + return mimetypes_guess_type(url, strict) class Response: @@ -39,24 +42,23 @@ def __init__( media_type: str = None, background: BackgroundTask = None, ) -> None: - if content is None: - self.body = b"" - else: - self.body = self.render(content) self.status_code = status_code if media_type is not None: self.media_type = media_type self.background = background + self.body = self.render(content) self.init_headers(headers) def render(self, content: typing.Any) -> bytes: + if content is None: + return b"" if isinstance(content, bytes): return content return content.encode(self.charset) def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: if headers is None: - raw_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + raw_headers: typing.List[typing.Tuple[bytes, bytes]] = [] populate_content_length = True populate_content_type = True else: @@ -97,21 +99,29 @@ def set_cookie( domain: str = None, secure: bool = False, httponly: bool = False, + samesite: str = "lax", ) -> None: - cookie = http.cookies.SimpleCookie() + cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie() cookie[key] = value if max_age is not None: - cookie[key]["max-age"] = max_age # type: ignore + cookie[key]["max-age"] = max_age if expires is not None: - cookie[key]["expires"] = expires # type: ignore + cookie[key]["expires"] = expires if path is not None: cookie[key]["path"] = path if domain is not None: cookie[key]["domain"] = domain if secure: - cookie[key]["secure"] = True # type: ignore + cookie[key]["secure"] = True if httponly: - cookie[key]["httponly"] = True # type: ignore + cookie[key]["httponly"] = True + if samesite is not None: + assert samesite.lower() in [ + "strict", + "lax", + "none", + ], "samesite must be either 'strict', 'lax' or 'none'" + cookie[key]["samesite"] = samesite cookie_val = cookie.output(header="").strip() self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) @@ -153,19 +163,18 @@ def render(self, content: typing.Any) -> bytes: ).encode("utf-8") -class UJSONResponse(JSONResponse): - media_type = "application/json" - - def render(self, content: typing.Any) -> bytes: - return ujson.dumps(content, ensure_ascii=False).encode("utf-8") - - class RedirectResponse(Response): def __init__( - self, url: typing.Union[str, URL], status_code: int = 302, headers: dict = None + self, + url: typing.Union[str, URL], + status_code: int = 307, + headers: dict = None, + background: BackgroundTask = None, ) -> None: - super().__init__(content=b"", status_code=status_code, headers=headers) - self.headers["location"] = quote_plus(str(url), safe=":/%#?&=@[]!$&'()*+,;") + super().__init__( + content=b"", status_code=status_code, headers=headers, background=background + ) + self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") class StreamingResponse(Response): @@ -177,7 +186,7 @@ def __init__( media_type: str = None, background: BackgroundTask = None, ) -> None: - if inspect.isasyncgen(content): + if isinstance(content, typing.AsyncIterable): self.body_iterator = content else: self.body_iterator = iterate_in_threadpool(content) @@ -186,7 +195,13 @@ def __init__( self.background = background self.init_headers(headers) - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def listen_for_disconnect(self, receive: Receive) -> None: + while True: + message = await receive() + if message["type"] == "http.disconnect": + break + + async def stream_response(self, send: Send) -> None: await send( { "type": "http.response.start", @@ -198,8 +213,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if not isinstance(chunk, bytes): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) + if self.background is not None: await self.background() @@ -209,7 +235,7 @@ class FileResponse(Response): def __init__( self, - path: str, + path: typing.Union[str, "os.PathLike[str]"], status_code: int = 200, headers: dict = None, media_type: str = None, @@ -218,7 +244,6 @@ def __init__( stat_result: os.stat_result = None, method: str = None, ) -> None: - assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse" self.path = path self.status_code = status_code self.filename = filename @@ -229,7 +254,13 @@ def __init__( self.background = background self.init_headers(headers) if self.filename is not None: - content_disposition = 'attachment; filename="{}"'.format(self.filename) + content_disposition_filename = quote(self.filename) + if content_disposition_filename != self.filename: + content_disposition = "attachment; filename*=utf-8''{}".format( + content_disposition_filename + ) + else: + content_disposition = f'attachment; filename="{self.filename}"' self.headers.setdefault("content-disposition", content_disposition) self.stat_result = stat_result if stat_result is not None: @@ -248,7 +279,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.stat_result is None: try: - stat_result = await aio_stat(self.path) + stat_result = await anyio.to_thread.run_sync(os.stat, self.path) self.set_stat_headers(stat_result) except FileNotFoundError: raise RuntimeError(f"File at path {self.path} does not exist.") @@ -264,9 +295,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: } ) if self.send_header_only: - await send({"type": "http.response.body"}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) else: - async with aiofiles.open(self.path, mode="rb") as file: + async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) diff --git a/starlette/routing.py b/starlette/routing.py index 0b5fb18b8..9a1a5e12d 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,7 +1,13 @@ import asyncio +import contextlib +import functools import inspect import re +import sys +import traceback +import types import typing +import warnings from enum import Enum from starlette.concurrency import run_in_threadpool @@ -13,6 +19,11 @@ from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover + class NoMatchFound(Exception): """ @@ -27,15 +38,25 @@ class Match(Enum): FULL = 2 +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: + """ + Correctly determines if an object is a coroutine function, + including those wrapped in functools.partial objects. + """ + while isinstance(obj, functools.partial): + obj = obj.func + return inspect.iscoroutinefunction(obj) + + def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = asyncio.iscoroutinefunction(func) + is_coroutine = iscoroutinefunction_or_partial(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request(scope, receive=receive) + request = Request(scope, receive=receive, send=send) if is_coroutine: response = await func(request) else: @@ -83,7 +104,7 @@ def replace_params( def compile_path( - path: str + path: str, ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]: """ Given a path string, like: "/{username:str}", return a three-tuple @@ -95,6 +116,7 @@ def compile_path( """ path_regex = "^" path_format = "" + duplicated_params = set() idx = 0 param_convertors = {} @@ -106,17 +128,25 @@ def compile_path( ), f"Unknown path convertor '{convertor_type}'" convertor = CONVERTOR_TYPES[convertor_type] - path_regex += path[idx : match.start()] + path_regex += re.escape(path[idx : match.start()]) path_regex += f"(?P<{param_name}>{convertor.regex})" path_format += path[idx : match.start()] path_format += "{%s}" % param_name + if param_name in param_convertors: + duplicated_params.add(param_name) + param_convertors[param_name] = convertor idx = match.end() - path_regex += path[idx:] + "$" + if duplicated_params: + names = ", ".join(sorted(duplicated_params)) + ending = "s" if len(duplicated_params) > 1 else "" + raise ValueError(f"Duplicated param name{ending} {names} at path {path}") + + path_regex += re.escape(path[idx:]) + "$" path_format += path[idx:] return re.compile(path_regex), path_format, param_convertors @@ -129,9 +159,28 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: def url_path_for(self, name: str, **path_params: str) -> URLPath: raise NotImplementedError() # pragma: no cover - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: raise NotImplementedError() # pragma: no cover + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + A route may be used in isolation as a stand-alone ASGI app. + This is a somewhat contrived case, as they'll almost always be used + within a Router, but could be useful for some tooling and minimal apps. + """ + match, child_scope = self.matches(scope) + if match == Match.NONE: + if scope["type"] == "http": + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + elif scope["type"] == "websocket": + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return + + scope.update(child_scope) + await self.handle(scope, receive, send) + class Route(BaseRoute): def __init__( @@ -149,7 +198,10 @@ def __init__( self.name = get_name(endpoint) if name is None else name self.include_in_schema = include_in_schema - if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(request) -> response`. self.app = request_response(endpoint) if methods is None: @@ -161,9 +213,9 @@ def __init__( if methods is None: self.methods = None else: - self.methods = set([method.upper() for method in methods]) + self.methods = {method.upper() for method in methods} if "GET" in self.methods: - self.methods |= set(["HEAD"]) + self.methods.add("HEAD") self.path_regex, self.path_format, self.param_convertors = compile_path(path) @@ -196,7 +248,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: assert not remaining_params return URLPath(path=path, protocol="http") - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: if self.methods and scope["method"] not in self.methods: if "app" in scope: raise HTTPException(status_code=405) @@ -231,8 +283,6 @@ def __init__( # Endpoint is a class. Treat it as ASGI. self.app = endpoint - regex = "^" + path + "$" - regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex) self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: @@ -261,7 +311,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: assert not remaining_params return URLPath(path=path, protocol="websocket") - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def __eq__(self, other: typing.Any) -> bool: @@ -277,7 +327,7 @@ def __init__( self, path: str, app: ASGIApp = None, - routes: typing.List[BaseRoute] = None, + routes: typing.Sequence[BaseRoute] = None, name: str = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" @@ -286,7 +336,7 @@ def __init__( ), "Either 'app=...', or 'routes=' must be specified" self.path = path.rstrip("/") if app is not None: - self.app = app # type: ASGIApp + self.app: ASGIApp = app else: self.app = Router(routes=routes) self.name = name @@ -310,9 +360,11 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: matched_path = path[: -len(remaining_path)] path_params = dict(scope.get("path_params", {})) path_params.update(matched_params) + root_path = scope.get("root_path", "") child_scope = { "path_params": path_params, - "root_path": scope.get("root_path", "") + matched_path, + "app_root_path": scope.get("app_root_path", root_path), + "root_path": root_path + matched_path, "path": remaining_path, "endpoint": self.app, } @@ -335,21 +387,24 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: else: # 'name' matches ":". remaining_name = name[len(self.name) + 1 :] + path_kwarg = path_params.get("path") path_params["path"] = "" - path, remaining_params = replace_params( + path_prefix, remaining_params = replace_params( self.path_format, self.param_convertors, path_params ) + if path_kwarg is not None: + remaining_params["path"] = path_kwarg for route in self.routes or []: try: url = route.url_path_for(remaining_name, **remaining_params) return URLPath( - path=path.rstrip("/") + str(url), protocol=url.protocol + path=path_prefix.rstrip("/") + str(url), protocol=url.protocol ) except NoMatchFound: pass raise NoMatchFound() - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def __eq__(self, other: typing.Any) -> bool: @@ -413,7 +468,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: pass raise NoMatchFound() - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def __eq__(self, other: typing.Any) -> bool: @@ -424,72 +479,216 @@ def __eq__(self, other: typing.Any) -> bool: ) -class Lifespan(BaseRoute): - def __init__( - self, on_startup: typing.Callable = None, on_shutdown: typing.Callable = None - ): - self.startup_handlers = [] if on_startup is None else [on_startup] - self.shutdown_handlers = [] if on_shutdown is None else [on_shutdown] +_T = typing.TypeVar("_T") - def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: - if scope["type"] == "lifespan": - return Match.FULL, {} - return Match.NONE, {} - def add_event_handler(self, event_type: str, func: typing.Callable) -> None: - assert event_type in ("startup", "shutdown") +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm - if event_type == "startup": - self.startup_handlers.append(func) + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc_value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> typing.Optional[bool]: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + +class _DefaultLifespan: + def __init__(self, router: "Router"): + self._router = router + + async def __aenter__(self) -> None: + await self._router.startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: + return self + + +class Router: + def __init__( + self, + routes: typing.Sequence[BaseRoute] = None, + redirect_slashes: bool = True, + default: ASGIApp = None, + on_startup: typing.Sequence[typing.Callable] = None, + on_shutdown: typing.Sequence[typing.Callable] = None, + lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None, + ) -> None: + self.routes = [] if routes is None else list(routes) + self.redirect_slashes = redirect_slashes + self.default = self.not_found if default is None else default + self.on_startup = [] if on_startup is None else list(on_startup) + self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) + + if lifespan is None: + self.lifespan_context: typing.Callable[ + [typing.Any], typing.AsyncContextManager + ] = _DefaultLifespan(self) + + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = asynccontextmanager( + lifespan, # type: ignore[arg-type] + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, # type: ignore[arg-type] + ) else: - assert event_type == "shutdown" - self.shutdown_handlers.append(func) + self.lifespan_context = lifespan - def on_event(self, event_type: str) -> typing.Callable: - def decorator(func: typing.Callable) -> typing.Callable: - self.add_event_handler(event_type, func) - return func + async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "websocket": + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return - return decorator + # If we're running inside a starlette application then raise an + # exception, so that the configurable exception handler can deal with + # returning the response. For plain ASGI apps, just return the response. + if "app" in scope: + raise HTTPException(status_code=404) + else: + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + + def url_path_for(self, name: str, **path_params: str) -> URLPath: + for route in self.routes: + try: + return route.url_path_for(name, **path_params) + except NoMatchFound: + pass + raise NoMatchFound() async def startup(self) -> None: - for handler in self.startup_handlers: + """ + Run any `.on_startup` event handlers. + """ + for handler in self.on_startup: if asyncio.iscoroutinefunction(handler): await handler() else: handler() async def shutdown(self) -> None: - for handler in self.shutdown_handlers: + """ + Run any `.on_shutdown` event handlers. + """ + for handler in self.on_shutdown: if asyncio.iscoroutinefunction(handler): await handler() else: handler() + async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + Handle ASGI lifespan messages, which allows us to manage application + startup and shutdown events. + """ + started = False + app = scope.get("app") + await receive() + try: + async with self.lifespan_context(app): + await send({"type": "lifespan.startup.complete"}) + started = True + await receive() + except BaseException: + exc_text = traceback.format_exc() + if started: + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + else: + await send({"type": "lifespan.startup.failed", "message": exc_text}) + raise + else: + await send({"type": "lifespan.shutdown.complete"}) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - message = await receive() - assert message["type"] == "lifespan.startup" - await self.startup() - await send({"type": "lifespan.startup.complete"}) + """ + The main entry point to the Router class. + """ + assert scope["type"] in ("http", "websocket", "lifespan") - message = await receive() - assert message["type"] == "lifespan.shutdown" - await self.shutdown() - await send({"type": "lifespan.shutdown.complete"}) + if "router" not in scope: + scope["router"] = self + if scope["type"] == "lifespan": + await self.lifespan(scope, receive, send) + return -class Router: - def __init__( - self, - routes: typing.List[BaseRoute] = None, - redirect_slashes: bool = True, - default: ASGIApp = None, - ) -> None: - self.routes = [] if routes is None else list(routes) - self.redirect_slashes = redirect_slashes - self.default = self.not_found if default is None else default - self.lifespan = Lifespan() + partial = None + + for route in self.routes: + # Determine if any route matches the incoming scope, + # and hand over to the matching route if found. + match, child_scope = route.matches(scope) + if match == Match.FULL: + scope.update(child_scope) + await route.handle(scope, receive, send) + return + elif match == Match.PARTIAL and partial is None: + partial = route + partial_scope = child_scope + if partial is not None: + #  Handle partial matches. These are cases where an endpoint is + # able to handle the request, but is not a preferred option. + # We use this in particular to deal with "405 Method Not Allowed". + scope.update(partial_scope) + await partial.handle(scope, receive, send) + return + + if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/": + redirect_scope = dict(scope) + if scope["path"].endswith("/"): + redirect_scope["path"] = redirect_scope["path"].rstrip("/") + else: + redirect_scope["path"] = redirect_scope["path"] + "/" + + for route in self.routes: + match, child_scope = route.matches(redirect_scope) + if match != Match.NONE: + redirect_url = URL(scope=redirect_scope) + response = RedirectResponse(url=str(redirect_url)) + await response(scope, receive, send) + return + + await self.default(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Router) and self.routes == other.routes + + # The following usages are now discouraged in favour of configuration + #  during Router.__init__(...) def mount(self, path: str, app: ASGIApp, name: str = None) -> None: route = Mount(path, app=app, name=name) self.routes.append(route) @@ -547,69 +746,17 @@ def decorator(func: typing.Callable) -> typing.Callable: return decorator - async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] == "websocket": - websocket_close = WebSocketClose() - await websocket_close(receive, send) - return + def add_event_handler(self, event_type: str, func: typing.Callable) -> None: + assert event_type in ("startup", "shutdown") - # If we're running inside a starlette application then raise an - # exception, so that the configurable exception handler can deal with - # returning the response. For plain ASGI apps, just return the response. - if "app" in scope: - raise HTTPException(status_code=404) + if event_type == "startup": + self.on_startup.append(func) else: - response = PlainTextResponse("Not Found", status_code=404) - await response(scope, receive, send) - - def url_path_for(self, name: str, **path_params: str) -> URLPath: - for route in self.routes: - try: - return route.url_path_for(name, **path_params) - except NoMatchFound: - pass - raise NoMatchFound() - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - assert scope["type"] in ("http", "websocket", "lifespan") - - if "router" not in scope: - scope["router"] = self - - partial = None - - for route in self.routes: - match, child_scope = route.matches(scope) - if match == Match.FULL: - scope.update(child_scope) - await route(scope, receive, send) - return - elif match == Match.PARTIAL and partial is None: - partial = route - partial_scope = child_scope - - if partial is not None: - scope.update(partial_scope) - await partial(scope, receive, send) - return - - if scope["type"] == "http" and self.redirect_slashes: - if not scope["path"].endswith("/"): - redirect_scope = dict(scope) - redirect_scope["path"] += "/" - - for route in self.routes: - match, child_scope = route.matches(redirect_scope) - if match != Match.NONE: - redirect_url = URL(scope=redirect_scope) - response = RedirectResponse(url=str(redirect_url)) - await response(scope, receive, send) - return + self.on_shutdown.append(func) - if scope["type"] == "lifespan": - await self.lifespan(scope, receive, send) - else: - await self.default(scope, receive, send) + def on_event(self, event_type: str) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: + self.add_event_handler(event_type, func) + return func - def __eq__(self, other: typing.Any) -> bool: - return isinstance(other, Router) and self.routes == other.routes + return decorator diff --git a/starlette/schemas.py b/starlette/schemas.py index 6d4b119f8..6ca764fdc 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -89,6 +89,8 @@ def parse_docstring(self, func_or_method: typing.Callable) -> dict: if not docstring: return {} + assert yaml is not None, "`pyyaml` must be installed to use parse_docstring." + # We support having regular docstrings before the schema # definition. Here we return just the schema part from # the docstring. diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 6f773d305..33ea0b033 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -4,7 +4,7 @@ import typing from email.utils import parsedate -from aiofiles.os import stat as aio_stat +import anyio from starlette.datastructures import URL, Headers from starlette.responses import ( @@ -15,6 +15,8 @@ ) from starlette.types import Receive, Scope, Send +PathLike = typing.Union[str, "os.PathLike[str]"] + class NotModifiedResponse(Response): NOT_MODIFIED_HEADERS = ( @@ -27,7 +29,7 @@ class NotModifiedResponse(Response): ) def __init__(self, headers: Headers): - return super().__init__( + super().__init__( status_code=304, headers={ name: value @@ -41,7 +43,7 @@ class StaticFiles: def __init__( self, *, - directory: str = None, + directory: PathLike = None, packages: typing.List[str] = None, html: bool = False, check_dir: bool = True, @@ -55,10 +57,10 @@ def __init__( raise RuntimeError(f"Directory '{directory}' does not exist") def get_directories( - self, directory: str = None, packages: typing.List[str] = None - ) -> typing.List[str]: + self, directory: PathLike = None, packages: typing.List[str] = None + ) -> typing.List[PathLike]: """ - Given `directory` and `packages` arugments, return a list of all the + Given `directory` and `packages` arguments, return a list of all the directories that should be used for serving static files from. """ directories = [] @@ -70,12 +72,14 @@ def get_directories( assert spec is not None, f"Package {package!r} could not be found." assert ( spec.origin is not None - ), "Directory 'statics' in package {package!r} could not be found." - directory = os.path.normpath(os.path.join(spec.origin, "..", "statics")) + ), f"Directory 'statics' in package {package!r} could not be found." + package_directory = os.path.normpath( + os.path.join(spec.origin, "..", "statics") + ) assert os.path.isdir( - directory - ), "Directory 'statics' in package {package!r} could not be found." - directories.append(directory) + package_directory + ), f"Directory 'statics' in package {package!r} could not be found." + directories.append(package_directory) return directories @@ -107,12 +111,6 @@ async def get_response(self, path: str, scope: Scope) -> Response: if scope["method"] not in ("GET", "HEAD"): return PlainTextResponse("Method Not Allowed", status_code=405) - if path.startswith(".."): - # Most clients will normalize the path, so we shouldn't normally - # get this, but don't allow misbehaving clients to break out of - # the static files directory. - return PlainTextResponse("Not Found", status_code=404) - full_path, stat_result = await self.lookup_path(path) if stat_result and stat.S_ISREG(stat_result.st_mode): @@ -136,8 +134,11 @@ async def get_response(self, path: str, scope: Scope) -> Response: # Check for '404.html' if we're in HTML mode. full_path, stat_result = await self.lookup_path("404.html") if stat_result is not None and stat.S_ISREG(stat_result.st_mode): - return self.file_response( - full_path, stat_result, scope, status_code=404 + return FileResponse( + full_path, + stat_result=stat_result, + method=scope["method"], + status_code=404, ) return PlainTextResponse("Not Found", status_code=404) @@ -145,19 +146,23 @@ async def get_response(self, path: str, scope: Scope) -> Response: async def lookup_path( self, path: str ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: - stat_result = None for directory in self.all_directories: - full_path = os.path.join(directory, path) + full_path = os.path.realpath(os.path.join(directory, path)) + directory = os.path.realpath(directory) + if os.path.commonprefix([full_path, directory]) != directory: + # Don't allow misbehaving clients to break out of the static files + # directory. + continue try: - stat_result = await aio_stat(full_path) - return (full_path, stat_result) + stat_result = await anyio.to_thread.run_sync(os.stat, full_path) + return full_path, stat_result except FileNotFoundError: pass - return ("", None) + return "", None def file_response( self, - full_path: str, + full_path: PathLike, stat_result: os.stat_result, scope: Scope, status_code: int = 200, @@ -182,7 +187,7 @@ async def check_config(self) -> None: return try: - stat_result = await aio_stat(self.directory) + stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) except FileNotFoundError: raise RuntimeError( f"StaticFiles directory '{self.directory}' does not exist." diff --git a/starlette/status.py b/starlette/status.py index 47c204b9e..b122ae85c 100644 --- a/starlette/status.py +++ b/starlette/status.py @@ -1,11 +1,14 @@ """ HTTP codes -See RFC 2616 - https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html -And RFC 6585 - https://tools.ietf.org/html/rfc6585 -And RFC 4918 - https://tools.ietf.org/html/rfc4918 +See HTTP Status Code Registry: +https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml + +And RFC 2324 - https://tools.ietf.org/html/rfc2324 """ HTTP_100_CONTINUE = 100 HTTP_101_SWITCHING_PROTOCOLS = 101 +HTTP_102_PROCESSING = 102 +HTTP_103_EARLY_HINTS = 103 HTTP_200_OK = 200 HTTP_201_CREATED = 201 HTTP_202_ACCEPTED = 202 @@ -14,6 +17,8 @@ HTTP_205_RESET_CONTENT = 205 HTTP_206_PARTIAL_CONTENT = 206 HTTP_207_MULTI_STATUS = 207 +HTTP_208_ALREADY_REPORTED = 208 +HTTP_226_IM_USED = 226 HTTP_300_MULTIPLE_CHOICES = 300 HTTP_301_MOVED_PERMANENTLY = 301 HTTP_302_FOUND = 302 @@ -22,6 +27,7 @@ HTTP_305_USE_PROXY = 305 HTTP_306_RESERVED = 306 HTTP_307_TEMPORARY_REDIRECT = 307 +HTTP_308_PERMANENT_REDIRECT = 308 HTTP_400_BAD_REQUEST = 400 HTTP_401_UNAUTHORIZED = 401 HTTP_402_PAYMENT_REQUIRED = 402 @@ -40,9 +46,13 @@ HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415 HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416 HTTP_417_EXPECTATION_FAILED = 417 +HTTP_418_IM_A_TEAPOT = 418 +HTTP_421_MISDIRECTED_REQUEST = 421 HTTP_422_UNPROCESSABLE_ENTITY = 422 HTTP_423_LOCKED = 423 HTTP_424_FAILED_DEPENDENCY = 424 +HTTP_425_TOO_EARLY = 425 +HTTP_426_UPGRADE_REQUIRED = 426 HTTP_428_PRECONDITION_REQUIRED = 428 HTTP_429_TOO_MANY_REQUESTS = 429 HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431 @@ -53,7 +63,10 @@ HTTP_503_SERVICE_UNAVAILABLE = 503 HTTP_504_GATEWAY_TIMEOUT = 504 HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 +HTTP_506_VARIANT_ALSO_NEGOTIATES = 506 HTTP_507_INSUFFICIENT_STORAGE = 507 +HTTP_508_LOOP_DETECTED = 508 +HTTP_510_NOT_EXTENDED = 510 HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 diff --git a/starlette/templating.py b/starlette/templating.py index 631b6bfee..36f613fdf 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -6,6 +6,12 @@ try: import jinja2 + + # @contextfunction renamed to @pass_context in Jinja 3.0, to be removed in 3.1 + if hasattr(jinja2, "pass_context"): + pass_context = jinja2.pass_context + else: # pragma: nocover + pass_context = jinja2.contextfunction except ImportError: # pragma: nocover jinja2 = None # type: ignore @@ -50,10 +56,10 @@ class Jinja2Templates: def __init__(self, directory: str) -> None: assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" - self.env = self.get_env(directory) + self.env = self._create_env(directory) - def get_env(self, directory: str) -> "jinja2.Environment": - @jinja2.contextfunction + def _create_env(self, directory: str) -> "jinja2.Environment": + @pass_context def url_for(context: dict, name: str, **path_params: typing.Any) -> str: request = context["request"] return request.url_for(name, **path_params) diff --git a/starlette/testclient.py b/starlette/testclient.py index 7679207af..08d03fa5c 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,19 +1,35 @@ import asyncio +import contextlib import http import inspect import io import json +import math import queue -import threading +import sys import types import typing +from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit +import anyio.abc import requests +from anyio.streams.stapled import StapledObjectStream from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect +if sys.version_info >= (3, 8): # pragma: no cover + from typing import TypedDict +else: # pragma: no cover + from typing_extensions import TypedDict + + +_PortalFactoryType = typing.Callable[ + [], typing.ContextManager[anyio.abc.BlockingPortal] +] + + # Annotations for `Session.request()` Cookies = typing.Union[ typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar @@ -25,7 +41,7 @@ AuthType = typing.Union[ typing.Tuple[str, str], requests.auth.AuthBase, - typing.Callable[[requests.Request], requests.Request], + typing.Callable[[requests.PreparedRequest], requests.PreparedRequest], ] @@ -87,15 +103,30 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await instance(receive, send) +class _AsyncBackend(TypedDict): + backend: str + backend_options: typing.Dict[str, typing.Any] + + class _ASGIAdapter(requests.adapters.HTTPAdapter): - def __init__(self, app: ASGI3App, raise_server_exceptions: bool = True) -> None: + def __init__( + self, + app: ASGI3App, + portal_factory: _PortalFactoryType, + raise_server_exceptions: bool = True, + root_path: str = "", + ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + self.portal_factory = portal_factory - def send( # type: ignore + def send( self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any ) -> requests.Response: - scheme, netloc, path, query, fragment = urlsplit(request.url) # type: ignore + scheme, netloc, path, query, fragment = ( + str(item) for item in urlsplit(request.url) + ) default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] @@ -108,7 +139,7 @@ def send( # type: ignore # Include the 'host' header. if "host" in request.headers: - headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + headers: typing.List[typing.Tuple[bytes, bytes]] = [] elif port == default_port: headers = [(b"host", host.encode())] else: @@ -123,13 +154,13 @@ def send( # type: ignore if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: - subprotocols = [] # type: typing.Sequence[str] + subprotocols: typing.Sequence[str] = [] else: subprotocols = [value.strip() for value in subprotocol.split(",")] scope = { "type": "websocket", "path": unquote(path), - "root_path": "", + "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), "headers": headers, @@ -137,7 +168,7 @@ def send( # type: ignore "server": [host, port], "subprotocols": subprotocols, } - session = WebSocketTestSession(self.app, scope) + session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) scope = { @@ -145,7 +176,7 @@ def send( # type: ignore "http_version": "1.1", "method": request.method, "path": unquote(path), - "root_path": "", + "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), "headers": headers, @@ -154,17 +185,24 @@ def send( # type: ignore "extensions": {"http.response.template": {}}, } + request_complete = False + response_started = False + response_complete: anyio.Event + raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()} + template = None + context = None + async def receive() -> Message: - nonlocal request_complete, response_complete + nonlocal request_complete if request_complete: - while not response_complete: - await asyncio.sleep(0.0001) + if not response_complete.is_set(): + await response_complete.wait() return {"type": "http.disconnect"} body = request.body if isinstance(body, str): - body_bytes = body.encode("utf-8") # type: bytes + body_bytes: bytes = body.encode("utf-8") elif body is None: body_bytes = b"" elif isinstance(body, types.GeneratorType): @@ -183,7 +221,7 @@ async def receive() -> Message: return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: - nonlocal raw_kwargs, response_started, response_complete, template, context + nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert ( @@ -193,7 +231,8 @@ async def send(message: Message) -> None: raw_kwargs["status"] = message["status"] raw_kwargs["reason"] = _get_reason_phrase(message["status"]) raw_kwargs["headers"] = [ - (key.decode(), value.decode()) for key, value in message["headers"] + (key.decode(), value.decode()) + for key, value in message.get("headers", []) ] raw_kwargs["preload_content"] = False raw_kwargs["original_response"] = _MockOriginalResponse( @@ -205,7 +244,7 @@ async def send(message: Message) -> None: response_started ), 'Received "http.response.body" without "http.response.start".' assert ( - not response_complete + not response_complete.is_set() ), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) @@ -213,29 +252,18 @@ async def send(message: Message) -> None: raw_kwargs["body"].write(body) if not more_body: raw_kwargs["body"].seek(0) - response_complete = True + response_complete.set() elif message["type"] == "http.response.template": template = message["template"] context = message["context"] - request_complete = False - response_started = False - response_complete = False - raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] - template = None - context = None - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(self.app(scope, receive, send)) + with self.portal_factory() as portal: + response_complete = portal.call(anyio.Event) + portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: - raise exc from None + raise exc if self.raise_server_exceptions: assert response_started, "TestClient did not receive any response." @@ -259,32 +287,45 @@ async def send(message: Message) -> None: class WebSocketTestSession: - def __init__(self, app: ASGI3App, scope: Scope) -> None: + def __init__( + self, + app: ASGI3App, + scope: Scope, + portal_factory: _PortalFactoryType, + ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None - self._loop = asyncio.new_event_loop() - self._receive_queue = queue.Queue() # type: queue.Queue - self._send_queue = queue.Queue() # type: queue.Queue - self._thread = threading.Thread(target=self._run) - self.send({"type": "websocket.connect"}) - self._thread.start() - message = self.receive() - self._raise_on_close(message) - self.accepted_subprotocol = message.get("subprotocol", None) + self.portal_factory = portal_factory + self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() + self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() def __enter__(self) -> "WebSocketTestSession": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context(self.portal_factory()) + + try: + _: "Future[None]" = self.portal.start_task_soon(self._run) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + except Exception: + self.exit_stack.close() + raise + self.accepted_subprotocol = message.get("subprotocol", None) return self def __exit__(self, *args: typing.Any) -> None: - self.close(1000) - self._thread.join() + try: + self.close(1000) + finally: + self.exit_stack.close() while not self._send_queue.empty(): message = self._send_queue.get() if isinstance(message, BaseException): raise message - def _run(self) -> None: + async def _run(self) -> None: """ The sub-thread in which the websocket session runs. """ @@ -292,13 +333,14 @@ def _run(self) -> None: receive = self._asgi_receive send = self._asgi_send try: - self._loop.run_until_complete(self.app(scope, receive, send)) + await self.app(scope, receive, send) except BaseException as exc: self._send_queue.put(exc) + raise async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): - await asyncio.sleep(0) + await anyio.sleep(0) return self._receive_queue.get() async def _asgi_send(self, message: Message) -> None: @@ -357,14 +399,22 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. + task: "Future[None]" + portal: typing.Optional[anyio.abc.BlockingPortal] = None def __init__( self, app: typing.Union[ASGI2App, ASGI3App], base_url: str = "http://testserver", raise_server_exceptions: bool = True, + root_path: str = "", + backend: str = "asyncio", + backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: - super(TestClient, self).__init__() + super().__init__() + self.async_backend = _AsyncBackend( + backend=backend, backend_options=backend_options or {} + ) if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app @@ -372,7 +422,10 @@ def __init__( app = typing.cast(ASGI2App, app) asgi_app = _WrapASGI2(app) #  type: ignore adapter = _ASGIAdapter( - asgi_app, raise_server_exceptions=raise_server_exceptions + asgi_app, + portal_factory=self._portal_factory, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, ) self.mount("http://", adapter) self.mount("https://", adapter) @@ -382,7 +435,17 @@ def __init__( self.app = asgi_app self.base_url = base_url - def request( + @contextlib.contextmanager + def _portal_factory( + self, + ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + if self.portal is not None: + yield self.portal + else: + with anyio.start_blocking_portal(**self.async_backend) as portal: + yield portal + + def request( # type: ignore self, method: str, url: str, @@ -441,36 +504,73 @@ def websocket_connect( return session - def __enter__(self) -> requests.Session: - loop = asyncio.get_event_loop() - self.send_queue = asyncio.Queue() # type: asyncio.Queue - self.receive_queue = asyncio.Queue() # type: asyncio.Queue - self.task = loop.create_task(self.lifespan()) - loop.run_until_complete(self.wait_startup()) + def __enter__(self) -> "TestClient": + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + + @stack.callback + def reset_portal() -> None: + self.portal = None + + self.stream_send = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.stream_receive = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self def __exit__(self, *args: typing.Any) -> None: - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wait_shutdown()) + self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan"} try: - await self.app(scope, self.receive_queue.get, self.send_queue.put) + await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: - await self.send_queue.put(None) + await self.stream_send.send(None) async def wait_startup(self) -> None: - await self.receive_queue.put({"type": "lifespan.startup"}) - message = await self.send_queue.get() - if message is None: - self.task.result() - assert message["type"] == "lifespan.startup.complete" + await self.stream_receive.send({"type": "lifespan.startup"}) + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + message = await receive() + assert message["type"] in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ) + if message["type"] == "lifespan.startup.failed": + await receive() async def wait_shutdown(self) -> None: - await self.receive_queue.put({"type": "lifespan.shutdown"}) - message = await self.send_queue.get() - if message is None: - self.task.result() - assert message["type"] == "lifespan.shutdown.complete" - await self.task + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + async with self.stream_send: + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/starlette/websockets.py b/starlette/websockets.py index 39af91d67..b9b8844d6 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -103,6 +103,27 @@ async def receive_json(self, mode: str = "text") -> typing.Any: text = message["bytes"].decode("utf-8") return json.loads(text) + async def iter_text(self) -> typing.AsyncIterator[str]: + try: + while True: + yield await self.receive_text() + except WebSocketDisconnect: + pass + + async def iter_bytes(self) -> typing.AsyncIterator[bytes]: + try: + while True: + yield await self.receive_bytes() + except WebSocketDisconnect: + pass + + async def iter_json(self) -> typing.AsyncIterator[typing.Any]: + try: + while True: + yield await self.receive_json() + except WebSocketDisconnect: + pass + async def send_text(self, data: str) -> None: await self.send({"type": "websocket.send", "text": data}) @@ -125,5 +146,5 @@ class WebSocketClose: def __init__(self, code: int = 1000) -> None: self.code = code - async def __call__(self, receive: Receive, send: Send) -> None: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.close", "code": self.code}) diff --git a/tests/.ignore_lifespan b/tests/.ignore_lifespan deleted file mode 100644 index 0a3358217..000000000 --- a/tests/.ignore_lifespan +++ /dev/null @@ -1,3 +0,0 @@ -[coverage:run] -omit = - starlette/middleware/lifespan.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..bb68aa5e2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,25 @@ +import functools +import sys + +import pytest + +from starlette.testclient import TestClient + +collect_ignore = ["test_graphql.py"] if sys.version_info >= (3, 10) else [] + + +@pytest.fixture +def no_trio_support(anyio_backend_name): + if anyio_backend_name == "trio": + pytest.skip("Trio not supported (yet!)") + + +@pytest.fixture +def test_client_factory(anyio_backend_name, anyio_backend_options): + # anyio_backend_name defined by: + # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on + return functools.partial( + TestClient, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) diff --git a/tests/middleware/__init__.py b/tests/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 04fa7d69c..8a8df4ea6 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,9 +1,10 @@ import pytest from starlette.applications import Starlette +from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient +from starlette.routing import Route class CustomMiddleware(BaseHTTPMiddleware): @@ -46,8 +47,8 @@ async def websocket_endpoint(session): await session.close() -def test_custom_middleware(): - client = TestClient(app) +def test_custom_middleware(test_client_factory): + client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -62,7 +63,7 @@ def test_custom_middleware(): assert text == "Hello, world!" -def test_middleware_decorator(): +def test_middleware_decorator(test_client_factory): app = Starlette() @app.route("/homepage") @@ -77,10 +78,82 @@ async def plaintext(request, call_next): response.headers["Custom"] = "Example" return response - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "OK" response = client.get("/homepage") assert response.text == "Homepage" assert response.headers["Custom"] == "Example" + + +def test_state_data_across_multiple_middlewares(test_client_factory): + expected_value1 = "foo" + expected_value2 = "bar" + + class aMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + request.state.foo = expected_value1 + response = await call_next(request) + return response + + class bMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + request.state.bar = expected_value2 + response = await call_next(request) + response.headers["X-State-Foo"] = request.state.foo + return response + + class cMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + response = await call_next(request) + response.headers["X-State-Bar"] = request.state.bar + return response + + app = Starlette() + app.add_middleware(aMiddleware) + app.add_middleware(bMiddleware) + app.add_middleware(cMiddleware) + + @app.route("/") + def homepage(request): + return PlainTextResponse("OK") + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "OK" + assert response.headers["X-State-Foo"] == expected_value1 + assert response.headers["X-State-Bar"] == expected_value2 + + +def test_app_middleware_argument(test_client_factory): + def homepage(request): + return PlainTextResponse("Homepage") + + app = Starlette( + routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.headers["Custom-Header"] == "Example" + + +def test_middleware_repr(): + middleware = Middleware(CustomMiddleware) + assert repr(middleware) == "Middleware(CustomMiddleware)" + + +def test_fully_evaluated_response(test_client_factory): + # Test for https://github.com/encode/starlette/issues/1022 + class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await call_next(request) + return PlainTextResponse("Custom") + + app = Starlette() + app.add_middleware(CustomMiddleware) + + client = test_client_factory(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index bcba70412..65252e502 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_cors_allow_all(): +def test_cors_allow_all(test_client_factory): app = Starlette() app.add_middleware( @@ -20,7 +19,63 @@ def test_cors_allow_all(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) + + # Test pre-flight response + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "X-Example", + } + response = client.options("/", headers=headers) + assert response.status_code == 200 + assert response.text == "OK" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-headers"] == "X-Example" + assert response.headers["access-control-allow-credentials"] == "true" + assert response.headers["vary"] == "Origin" + + # Test standard response + headers = {"Origin": "https://example.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "*" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test standard credentialed response + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test non-CORS response + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Homepage" + assert "access-control-allow-origin" not in response.headers + + +def test_cors_allow_all_except_credentials(test_client_factory): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_headers=["*"], + allow_methods=["*"], + expose_headers=["X-Status"], + ) + + @app.route("/") + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = test_client_factory(app) # Test pre-flight response headers = { @@ -33,6 +88,8 @@ def homepage(request): assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-allow-headers"] == "X-Example" + assert "access-control-allow-credentials" not in response.headers + assert "vary" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -41,6 +98,7 @@ def homepage(request): assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-expose-headers"] == "X-Status" + assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") @@ -49,7 +107,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_specific_origin(): +def test_cors_allow_specific_origin(test_client_factory): app = Starlette() app.add_middleware( @@ -62,7 +120,7 @@ def test_cors_allow_specific_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -74,7 +132,10 @@ def homepage(request): assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://example.org" - assert response.headers["access-control-allow-headers"] == "X-Example, Content-Type" + assert response.headers["access-control-allow-headers"] == ( + "Accept, Accept-Language, Content-Language, Content-Type, X-Example" + ) + assert "access-control-allow-credentials" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -82,6 +143,7 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") @@ -90,7 +152,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_disallowed_preflight(): +def test_cors_disallowed_preflight(test_client_factory): app = Starlette() app.add_middleware( @@ -103,7 +165,7 @@ def test_cors_disallowed_preflight(): def homepage(request): pass # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -114,22 +176,117 @@ def homepage(request): response = client.options("/", headers=headers) assert response.status_code == 400 assert response.text == "Disallowed CORS origin, method, headers" + assert "access-control-allow-origin" not in response.headers + + # Bug specific test, https://github.com/encode/starlette/pull/1199 + # Test preflight response text with multiple disallowed headers + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "X-Nope-1, X-Nope-2", + } + response = client.options("/", headers=headers) + assert response.text == "Disallowed CORS headers" -def test_cors_allow_origin_regex(): +def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( + test_client_factory, +): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["POST"], + allow_credentials=True, + ) + + @app.route("/") + def homepage(request): + return # pragma: no cover + + client = test_client_factory(app) + + # Test pre-flight response + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "POST", + } + response = client.options( + "/", + headers=headers, + ) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" + assert response.headers["vary"] == "Origin" + + +def test_cors_preflight_allow_all_methods(test_client_factory): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + ) + + @app.route("/") + def homepage(request): + pass # pragma: no cover + + client = test_client_factory(app) + + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "POST", + } + + for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"): + response = client.options("/", headers=headers) + assert response.status_code == 200 + assert method in response.headers["access-control-allow-methods"] + + +def test_cors_allow_all_methods(test_client_factory): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + ) + + @app.route( + "/", methods=("delete", "get", "head", "options", "patch", "post", "put") + ) + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = test_client_factory(app) + + headers = {"Origin": "https://example.org"} + + for method in ("delete", "get", "head", "options", "patch", "post", "put"): + response = getattr(client, method)("/", headers=headers, json={}) + assert response.status_code == 200 + + +def test_cors_allow_origin_regex(test_client_factory): app = Starlette() app.add_middleware( CORSMiddleware, allow_headers=["X-Example", "Content-Type"], - allow_origin_regex="https://*", + allow_origin_regex="https://.*", + allow_credentials=True, ) @app.route("/") def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test standard response headers = {"Origin": "https://example.org"} @@ -137,8 +294,17 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" - # Test diallowed standard response + # Test standard credentialed response + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test disallowed standard response # Note that enforcement is a browser concern. The disallowed-ness is reflected # in the lack of an "access-control-allow-origin" header in the response. headers = {"Origin": "http://example.org"} @@ -157,7 +323,10 @@ def homepage(request): assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://another.com" - assert response.headers["access-control-allow-headers"] == "X-Example, Content-Type" + assert response.headers["access-control-allow-headers"] == ( + "Accept, Accept-Language, Content-Language, Content-Type, X-Example" + ) + assert response.headers["access-control-allow-credentials"] == "true" # Test disallowed pre-flight response headers = { @@ -171,7 +340,41 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_credentialed_requests_return_specific_origin(): +def test_cors_allow_origin_regex_fullmatch(test_client_factory): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_headers=["X-Example", "Content-Type"], + allow_origin_regex=r"https://.*\.example.org", + ) + + @app.route("/") + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = test_client_factory(app) + + # Test standard response + headers = {"Origin": "https://subdomain.example.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert ( + response.headers["access-control-allow-origin"] + == "https://subdomain.example.org" + ) + assert "access-control-allow-credentials" not in response.headers + + # Test diallowed standard response + headers = {"Origin": "https://subdomain.example.org.hacker.com"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert "access-control-allow-origin" not in response.headers + + +def test_cors_credentialed_requests_return_specific_origin(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -180,7 +383,7 @@ def test_cors_credentialed_requests_return_specific_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test credentialed request headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} @@ -188,9 +391,10 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers -def test_cors_vary_header_defaults_to_origin(): +def test_cors_vary_header_defaults_to_origin(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) @@ -201,19 +405,35 @@ def test_cors_vary_header_defaults_to_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers=headers) assert response.status_code == 200 assert response.headers["vary"] == "Origin" -def test_cors_vary_header_is_properly_set(): +def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory): app = Starlette() - app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) + app.add_middleware(CORSMiddleware, allow_origins=["*"]) - headers = {"Origin": "https://example.org"} + @app.route("/") + def homepage(request): + return PlainTextResponse( + "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} + ) + + client = test_client_factory(app) + + response = client.get("/", headers={"Origin": "https://someplace.org"}) + assert response.status_code == 200 + assert response.headers["vary"] == "Accept-Encoding" + + +def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory): + app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["*"]) @app.route("/") def homepage(request): @@ -221,15 +441,40 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) - response = client.get("/", headers=headers) + response = client.get( + "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} + ) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): +def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( + test_client_factory, +): app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) + + @app.route("/") + def homepage(request): + return PlainTextResponse( + "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} + ) + + client = test_client_factory(app) + + response = client.get("/", headers={"Origin": "https://example.org"}) + assert response.status_code == 200 + assert response.headers["vary"] == "Accept-Encoding, Origin" + + +def test_cors_allowed_origin_does_not_leak_between_credentialed_requests( + test_client_factory, +): + app = Starlette() + app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"] ) @@ -238,14 +483,17 @@ def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in response.headers response = client.get( "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} ) assert response.headers["access-control-allow-origin"] == "https://someplace.org" + assert "access-control-allow-credentials" not in response.headers response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in response.headers diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 768a4ee0b..146b6f90d 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -2,11 +2,10 @@ from starlette.middleware.errors import ServerErrorMiddleware from starlette.responses import JSONResponse, Response -from starlette.testclient import TestClient -from starlette.websockets import WebSocket, WebSocketDisconnect +from starlette.websockets import WebSocketDisconnect -def test_handler(): +def test_handler(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") @@ -14,54 +13,55 @@ def error_500(request, exc): return JSONResponse({"detail": "Server Error"}, status_code=500) app = ServerErrorMiddleware(app, handler=error_500) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_debug_text(): +def test_debug_text(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.headers["content-type"].startswith("text/plain") - assert "RuntimeError" in response.text + assert "RuntimeError: Something went wrong" in response.text -def test_debug_html(): +def test_debug_html(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 assert response.headers["content-type"].startswith("text/html") assert "RuntimeError" in response.text -def test_debug_after_response_sent(): +def test_debug_after_response_sent(test_client_factory): async def app(scope, receive, send): response = Response(b"", status_code=204) await response(scope, receive, send) raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") -def test_debug_websocket(): +def test_debug_websocket(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app) with pytest.raises(WebSocketDisconnect): - client = TestClient(app) - client.websocket_connect("/") + client = test_client_factory(app) + with client.websocket_connect("/"): + pass # pragma: nocover diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index cd989b8c1..b917ea4db 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.gzip import GZipMiddleware from starlette.responses import PlainTextResponse, StreamingResponse -from starlette.testclient import TestClient -def test_gzip_responses(): +def test_gzip_responses(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -13,7 +12,7 @@ def test_gzip_responses(): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -21,7 +20,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) < 4000 -def test_gzip_not_in_accept_encoding(): +def test_gzip_not_in_accept_encoding(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -30,7 +29,7 @@ def test_gzip_not_in_accept_encoding(): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -38,7 +37,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 4000 -def test_gzip_ignored_for_small_responses(): +def test_gzip_ignored_for_small_responses(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -47,7 +46,7 @@ def test_gzip_ignored_for_small_responses(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "OK" @@ -55,7 +54,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 2 -def test_gzip_streaming_response(): +def test_gzip_streaming_response(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -69,7 +68,7 @@ async def generator(bytes, count): streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 15f1e3fe0..8db950634 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_https_redirect_middleware(): +def test_https_redirect_middleware(test_client_factory): app = Starlette() app.add_middleware(HTTPSRedirectMiddleware) @@ -13,26 +12,26 @@ def test_https_redirect_middleware(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app, base_url="https://testserver") + client = test_client_factory(app, base_url="https://testserver") response = client.get("/") assert response.status_code == 200 - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", allow_redirects=False) - assert response.status_code == 301 + assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:80") + client = test_client_factory(app, base_url="http://testserver:80") response = client.get("/", allow_redirects=False) - assert response.status_code == 301 + assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:443") + client = test_client_factory(app, base_url="http://testserver:443") response = client.get("/", allow_redirects=False) - assert response.status_code == 301 + assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:123") + client = test_client_factory(app, base_url="http://testserver:123") response = client.get("/", allow_redirects=False) - assert response.status_code == 301 + assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" diff --git a/tests/middleware/test_lifespan.py b/tests/middleware/test_lifespan.py deleted file mode 100644 index 822468964..000000000 --- a/tests/middleware/test_lifespan.py +++ /dev/null @@ -1,108 +0,0 @@ -import pytest - -from starlette.applications import Starlette -from starlette.responses import PlainTextResponse -from starlette.routing import Lifespan, Route, Router -from starlette.testclient import TestClient - - -def test_routed_lifespan(): - startup_complete = False - shutdown_complete = False - - def hello_world(request): - return PlainTextResponse("hello, world") - - def run_startup(): - nonlocal startup_complete - startup_complete = True - - def run_shutdown(): - nonlocal shutdown_complete - shutdown_complete = True - - app = Router( - routes=[ - Lifespan(on_startup=run_startup, on_shutdown=run_shutdown), - Route("/", hello_world), - ] - ) - - assert not startup_complete - assert not shutdown_complete - with TestClient(app) as client: - assert startup_complete - assert not shutdown_complete - client.get("/") - assert startup_complete - assert shutdown_complete - - -def test_raise_on_startup(): - def run_startup(): - raise RuntimeError() - - app = Router(routes=[Lifespan(on_startup=run_startup)]) - - with pytest.raises(RuntimeError): - with TestClient(app): - pass # pragma: nocover - - -def test_raise_on_shutdown(): - def run_shutdown(): - raise RuntimeError() - - app = Router(routes=[Lifespan(on_shutdown=run_shutdown)]) - - with pytest.raises(RuntimeError): - with TestClient(app): - pass # pragma: nocover - - -def test_app_lifespan(): - startup_complete = False - shutdown_complete = False - app = Starlette() - - @app.on_event("startup") - def run_startup(): - nonlocal startup_complete - startup_complete = True - - @app.on_event("shutdown") - def run_shutdown(): - nonlocal shutdown_complete - shutdown_complete = True - - assert not startup_complete - assert not shutdown_complete - with TestClient(app): - assert startup_complete - assert not shutdown_complete - assert startup_complete - assert shutdown_complete - - -def test_app_async_lifespan(): - startup_complete = False - shutdown_complete = False - app = Starlette() - - @app.on_event("startup") - async def run_startup(): - nonlocal startup_complete - startup_complete = True - - @app.on_event("shutdown") - async def run_shutdown(): - nonlocal shutdown_complete - shutdown_complete = True - - assert not startup_complete - assert not shutdown_complete - with TestClient(app): - assert startup_complete - assert not shutdown_complete - assert startup_complete - assert shutdown_complete diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 3f71232e6..314f2be58 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -3,7 +3,6 @@ from starlette.applications import Starlette from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse -from starlette.testclient import TestClient def view_session(request): @@ -29,10 +28,10 @@ def create_app(): return app -def test_session(): +def test_session(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/view_session") assert response.json() == {"session": {}} @@ -56,10 +55,10 @@ def test_session(): assert response.json() == {"session": {}} -def test_session_expires(): +def test_session_expires(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", max_age=-1) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -72,11 +71,11 @@ def test_session_expires(): assert response.json() == {"session": {}} -def test_secure_session(): +def test_secure_session(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", https_only=True) - secure_client = TestClient(app, base_url="https://testserver") - unsecure_client = TestClient(app, base_url="http://testserver") + secure_client = test_client_factory(app, base_url="https://testserver") + unsecure_client = test_client_factory(app, base_url="http://testserver") response = unsecure_client.get("/view_session") assert response.json() == {"session": {}} @@ -101,3 +100,15 @@ def test_secure_session(): response = secure_client.get("/view_session") assert response.json() == {"session": {}} + + +def test_session_cookie_subpath(test_client_factory): + app = create_app() + second_app = create_app() + second_app.add_middleware(SessionMiddleware, secret_key="example") + app.mount("/second_app", second_app) + client = test_client_factory(app, base_url="http://testserver/second_app") + response = client.post("second_app/update_session", json={"some": "data"}) + cookie = response.headers["set-cookie"] + cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] + assert cookie_path == "/second_app" diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index 934f2477b..de9c79e66 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_trusted_host_middleware(): +def test_trusted_host_middleware(test_client_factory): app = Starlette() app.add_middleware( @@ -15,15 +14,15 @@ def test_trusted_host_middleware(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 - client = TestClient(app, base_url="http://subdomain.testserver") + client = test_client_factory(app, base_url="http://subdomain.testserver") response = client.get("/") assert response.status_code == 200 - client = TestClient(app, base_url="http://invalidhost") + client = test_client_factory(app, base_url="http://invalidhost") response = client.get("/") assert response.status_code == 400 @@ -34,7 +33,7 @@ def test_default_allowed_hosts(): assert middleware.allowed_hosts == ["*"] -def test_www_redirect(): +def test_www_redirect(test_client_factory): app = Starlette() app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"]) @@ -43,7 +42,7 @@ def test_www_redirect(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app, base_url="https://example.com") + client = test_client_factory(app, base_url="https://example.com") response = client.get("/") assert response.status_code == 200 assert response.url == "https://www.example.com/" diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 202726097..bcb4cd6ff 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -3,7 +3,6 @@ import pytest from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from starlette.testclient import TestClient def hello_world(environ, start_response): @@ -46,41 +45,41 @@ def return_exc_info(environ, start_response): return [output] -def test_wsgi_get(): +def test_wsgi_get(test_client_factory): app = WSGIMiddleware(hello_world) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello World!\n" -def test_wsgi_post(): +def test_wsgi_post(test_client_factory): app = WSGIMiddleware(echo_body) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"example": 123}) assert response.status_code == 200 assert response.text == '{"example": 123}' -def test_wsgi_exception(): +def test_wsgi_exception(test_client_factory): # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") -def test_wsgi_exc_info(): +def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(return_exc_info) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): response = client.get("/") app = WSGIMiddleware(return_exc_info) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.text == "Internal Server Error" @@ -128,3 +127,18 @@ def test_build_environ(): "wsgi.url_scheme": "https", "wsgi.version": (1, 0), } + + +def test_build_environ_encoding() -> None: + scope = { + "type": "http", + "http_version": "1.1", + "method": "GET", + "path": "/小星", + "root_path": "/中国", + "query_string": b"a=123&b=456", + "headers": [], + } + environ = build_environ(scope, b"") + assert environ["SCRIPT_NAME"] == "/中国".encode().decode("latin-1") + assert environ["PATH_INFO"] == "/小星".encode().decode("latin-1") diff --git a/tests/test_applications.py b/tests/test_applications.py index 9c2d4d25d..aaccabab0 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,5 +1,6 @@ import asyncio import os +import sys import pytest @@ -11,8 +12,11 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient -from starlette.websockets import WebSocketDisconnect + +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover app = Starlette() @@ -113,14 +117,17 @@ def custom_ws_exception_handler(websocket, exc): loop.run_until_complete(websocket.close(code=status.WS_1013_TRY_AGAIN_LATER)) -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client def test_url_path_for(): assert app.url_path_for("func_homepage") == "/func" -def test_func_route(): +def test_func_route(client): response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" @@ -130,51 +137,51 @@ def test_func_route(): assert response.text == "" -def test_async_route(): +def test_async_route(client): response = client.get("/async") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_class_route(): +def test_class_route(client): response = client.get("/class") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_mounted_route(): +def test_mounted_route(client): response = client.get("/users/") assert response.status_code == 200 assert response.text == "Hello, everyone!" -def test_mounted_route_path_params(): +def test_mounted_route_path_params(client): response = client.get("/users/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_subdomain_route(): - client = TestClient(app, base_url="https://foo.example.org/") +def test_subdomain_route(test_client_factory): + client = test_client_factory(app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 assert response.text == "Subdomain: foo" -def test_websocket_route(): +def test_websocket_route(client): with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_400(): +def test_400(client): response = client.get("/404") assert response.status_code == 404 assert response.json() == {"detail": "Not Found"} -def test_405(): +def test_405(client): response = client.post("/func") assert response.status_code == 405 assert response.json() == {"detail": "Custom message"} @@ -184,15 +191,14 @@ def test_405(): assert response.json() == {"detail": "Custom message"} -def test_500(): - client = TestClient(app, raise_server_exceptions=False) +def test_500(test_client_factory): + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/500") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_websocket_raise_websocket_exception(): - client = TestClient(app) +def test_websocket_raise_websocket_exception(client): with client.websocket_connect("/ws-raise-websocket") as session: response = session.receive() assert response == { @@ -201,8 +207,7 @@ def test_websocket_raise_websocket_exception(): } -def test_websocket_raise_custom_exception(): - client = TestClient(app) +def test_websocket_raise_custom_exception(client): with client.websocket_connect("/ws-raise-custom") as session: response = session.receive() assert response == { @@ -211,8 +216,8 @@ def test_websocket_raise_custom_exception(): } -def test_middleware(): - client = TestClient(app, base_url="http://incorrecthost") +def test_middleware(test_client_factory): + client = test_client_factory(app, base_url="http://incorrecthost") response = client.get("/func") assert response.status_code == 400 assert response.text == "Invalid host header" @@ -245,7 +250,7 @@ def test_routes(): ] -def test_app_mount(tmpdir): +def test_app_mount(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -253,7 +258,7 @@ def test_app_mount(tmpdir): app = Starlette() app.mount("/static", StaticFiles(directory=tmpdir)) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/static/example.txt") assert response.status_code == 200 @@ -264,7 +269,7 @@ def test_app_mount(tmpdir): assert response.text == "Method Not Allowed" -def test_app_debug(): +def test_app_debug(test_client_factory): app = Starlette() app.debug = True @@ -272,27 +277,27 @@ def test_app_debug(): async def homepage(request): raise RuntimeError() - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert "RuntimeError" in response.text assert app.debug -def test_app_add_route(): +def test_app_add_route(test_client_factory): app = Starlette() async def homepage(request): return PlainTextResponse("Hello, World!") app.add_route("/", homepage) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_app_add_websocket_route(): +def test_app_add_websocket_route(test_client_factory): app = Starlette() async def websocket_endpoint(session): @@ -301,14 +306,14 @@ async def websocket_endpoint(session): await session.close() app.add_websocket_route("/ws", websocket_endpoint) - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_app_add_event_handler(): +def test_app_add_event_handler(test_client_factory): startup_complete = False cleanup_complete = False app = Starlette() @@ -326,7 +331,82 @@ def run_cleanup(): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete + + +def test_app_async_cm_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + @asynccontextmanager + async def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete + + +deprecated_lifespan = pytest.mark.filterwarnings( + r"ignore" + r":(async )?generator function lifespans are deprecated, use an " + r"@contextlib\.asynccontextmanager function instead" + r":DeprecationWarning" + r":starlette.routing" +) + + +@deprecated_lifespan +def test_app_async_gen_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + async def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete + + +@deprecated_lifespan +def test_app_sync_gen_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 372ea81d8..43c7ab96d 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -15,7 +15,6 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect @@ -117,6 +116,76 @@ async def websocket_endpoint(websocket): ) +def async_inject_decorator(**kwargs): + def wrapper(endpoint): + async def app(request): + return await endpoint(request=request, **kwargs) + + return app + + return wrapper + + +@app.route("/dashboard/decorated") +@async_inject_decorator(additional="payload") +@requires("authenticated") +async def decorated_async(request, additional): + return JSONResponse( + { + "authenticated": request.user.is_authenticated, + "user": request.user.display_name, + "additional": additional, + } + ) + + +def sync_inject_decorator(**kwargs): + def wrapper(endpoint): + def app(request): + return endpoint(request=request, **kwargs) + + return app + + return wrapper + + +@app.route("/dashboard/decorated/sync") +@sync_inject_decorator(additional="payload") +@requires("authenticated") +def decorated_sync(request, additional): + return JSONResponse( + { + "authenticated": request.user.is_authenticated, + "user": request.user.display_name, + "additional": additional, + } + ) + + +def ws_inject_decorator(**kwargs): + def wrapper(endpoint): + def app(websocket): + return endpoint(websocket=websocket, **kwargs) + + return app + + return wrapper + + +@app.websocket_route("/ws/decorated") +@ws_inject_decorator(additional="payload") +@requires("authenticated") +async def websocket_endpoint_decorated(websocket, additional): + await websocket.accept() + await websocket.send_json( + { + "authenticated": websocket.user.is_authenticated, + "user": websocket.user.display_name, + "additional": additional, + } + ) + + def test_invalid_decorator_usage(): with pytest.raises(Exception): @@ -125,8 +194,8 @@ def foo(): pass # pragma: nocover -def test_user_interface(): - with TestClient(app) as client: +def test_user_interface(test_client_factory): + with test_client_factory(app) as client: response = client.get("/") assert response.status_code == 200 assert response.json() == {"authenticated": False, "user": ""} @@ -136,8 +205,8 @@ def test_user_interface(): assert response.json() == {"authenticated": True, "user": "tomchristie"} -def test_authentication_required(): - with TestClient(app) as client: +def test_authentication_required(test_client_factory): + with test_client_factory(app) as client: response = client.get("/dashboard") assert response.status_code == 403 @@ -159,18 +228,46 @@ def test_authentication_required(): assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/dashboard/decorated", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == { + "authenticated": True, + "user": "tomchristie", + "additional": "payload", + } + + response = client.get("/dashboard/decorated") + assert response.status_code == 403 + + response = client.get( + "/dashboard/decorated/sync", auth=("tomchristie", "example") + ) + assert response.status_code == 200 + assert response.json() == { + "authenticated": True, + "user": "tomchristie", + "additional": "payload", + } + + response = client.get("/dashboard/decorated/sync") + assert response.status_code == 403 + response = client.get("/dashboard", headers={"Authorization": "basic foobar"}) assert response.status_code == 400 assert response.text == "Invalid basic auth credentials" -def test_websocket_authentication_required(): - with TestClient(app) as client: +def test_websocket_authentication_required(test_client_factory): + with test_client_factory(app) as client: with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws") + with client.websocket_connect("/ws"): + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}) + with client.websocket_connect( + "/ws", headers={"Authorization": "basic foobar"} + ): + pass # pragma: nocover with client.websocket_connect( "/ws", auth=("tomchristie", "example") @@ -178,9 +275,29 @@ def test_websocket_authentication_required(): data = websocket.receive_json() assert data == {"authenticated": True, "user": "tomchristie"} + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect("/ws/decorated"): + pass # pragma: nocover + + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect( + "/ws/decorated", headers={"Authorization": "basic foobar"} + ): + pass # pragma: nocover + + with client.websocket_connect( + "/ws/decorated", auth=("tomchristie", "example") + ) as websocket: + data = websocket.receive_json() + assert data == { + "authenticated": True, + "user": "tomchristie", + "additional": "payload", + } + -def test_authentication_redirect(): - with TestClient(app) as client: +def test_authentication_redirect(test_client_factory): + with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 assert response.url == "http://testserver/" @@ -219,8 +336,8 @@ def control_panel(request): ) -def test_custom_on_error(): - with TestClient(other_app) as client: +def test_custom_on_error(test_client_factory): + with test_client_factory(other_app) as client: response = client.get("/control-panel", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} diff --git a/tests/test_background.py b/tests/test_background.py index d9d7ddd87..e299ec362 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,9 +1,8 @@ from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response -from starlette.testclient import TestClient -def test_async_task(): +def test_async_task(test_client_factory): TASK_COMPLETE = False async def async_task(): @@ -16,13 +15,13 @@ async def app(scope, receive, send): response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_sync_task(): +def test_sync_task(test_client_factory): TASK_COMPLETE = False def sync_task(): @@ -35,13 +34,13 @@ async def app(scope, receive, send): response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_multiple_tasks(): +def test_multiple_tasks(test_client_factory): TASK_COUNTER = 0 def increment(amount): @@ -58,7 +57,7 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "tasks initiated" assert TASK_COUNTER == 1 + 2 + 3 diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 000000000..cc5eba974 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,22 @@ +import anyio +import pytest + +from starlette.concurrency import run_until_first_complete + + +@pytest.mark.anyio +async def test_run_until_first_complete(): + task1_finished = anyio.Event() + task2_finished = anyio.Event() + + async def task1(): + task1_finished.set() + + async def task2(): + await task1_finished.wait() + await anyio.sleep(0) # pragma: nocover + task2_finished.set() # pragma: nocover + + await run_until_first_complete((task1, {}), (task2, {})) + assert task1_finished.is_set() + assert not task2_finished.is_set() diff --git a/tests/test_config.py b/tests/test_config.py index f0ea1d450..ae91f9695 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pytest @@ -44,6 +45,10 @@ def test_config(tmpdir, monkeypatch): with pytest.raises(ValueError): config.get("REQUEST_HOSTNAME", cast=bool) + config = Config(Path(path)) + REQUEST_HOSTNAME = config("REQUEST_HOSTNAME") + assert REQUEST_HOSTNAME == "example.com" + config = Config() monkeypatch.setenv("STARLETTE_EXAMPLE_TEST", "123") monkeypatch.setenv("BOOL_AS_INT", "1") diff --git a/tests/test_database.py b/tests/test_database.py index 258a71ec5..1230fc8f6 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,7 +4,6 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse -from starlette.testclient import TestClient DATABASE_URL = "sqlite:///test.db" @@ -19,6 +18,9 @@ ) +pytestmark = pytest.mark.usefixtures("no_trio_support") + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -87,8 +89,8 @@ async def read_note_text(request): return JSONResponse(result[0]) -def test_database(): - with TestClient(app) as client: +def test_database(test_client_factory): + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "buy the milk", "completed": True} ) @@ -122,8 +124,8 @@ def test_database(): assert response.json() == "buy the milk" -def test_database_execute_many(): - with TestClient(app) as client: +def test_database_execute_many(test_client_factory): + with test_client_factory(app) as client: response = client.get("/notes") data = [ @@ -141,11 +143,11 @@ def test_database_execute_many(): ] -def test_database_isolated_during_test_cases(): +def test_database_isolated_during_test_cases(test_client_factory): """ Using `TestClient` as a context manager """ - with TestClient(app) as client: + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True} ) @@ -155,7 +157,7 @@ def test_database_isolated_during_test_cases(): assert response.status_code == 200 assert response.json() == [{"text": "just one note", "completed": True}] - with TestClient(app) as client: + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True} ) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index e5e255e58..bb71ba870 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,5 +1,7 @@ import io +import pytest + from starlette.datastructures import ( URL, CommaSeparatedStrings, @@ -8,6 +10,7 @@ MultiDict, MutableHeaders, QueryParams, + UploadFile, ) @@ -36,6 +39,19 @@ def test_url(): assert new.hostname == "example.com" +def test_url_query_params(): + u = URL("https://example.org/path/?page=3") + assert u.query == "page=3" + u = u.include_query_params(page=4) + assert str(u) == "https://example.org/path/?page=4" + u = u.include_query_params(search="testing") + assert str(u) == "https://example.org/path/?page=4&search=testing" + u = u.replace_query_params(order="name") + assert str(u) == "https://example.org/path/?order=name" + u = u.remove_query_params("order") + assert str(u) == "https://example.org/path/" + + def test_hidden_password(): u = URL("https://example.org/path/to/somewhere") assert repr(u) == "URL('https://example.org/path/to/somewhere')" @@ -154,6 +170,17 @@ def test_headers_mutablecopy(): assert c.items() == [("a", "abc"), ("b", "789")] +def test_url_blank_params(): + q = QueryParams("a=123&abc&def&b=456") + assert "a" in q + assert "abc" in q + assert "def" in q + assert "b" in q + assert len(q.get("abc")) == 0 + assert len(q["a"]) == 3 + assert list(q.keys()) == ["a", "abc", "def", "b"] + + def test_queryparams(): q = QueryParams("a=123&a=456&b=789") assert "a" in q @@ -186,6 +213,20 @@ def test_queryparams(): assert QueryParams(q) == q +class BigUploadFile(UploadFile): + spool_max_size = 1024 + + +@pytest.mark.anyio +async def test_upload_file(): + big_file = BigUploadFile("big-file") + await big_file.write(b"big-data" * 512) + await big_file.write(b"big-data") + await big_file.seek(0) + assert await big_file.read(1024) == b"big-data" * 128 + await big_file.close() + + def test_formdata(): upload = io.BytesIO(b"test") form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index e491c085f..e57d47486 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -3,7 +3,6 @@ from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint from starlette.responses import PlainTextResponse from starlette.routing import Route, Router -from starlette.testclient import TestClient class Homepage(HTTPEndpoint): @@ -18,46 +17,50 @@ async def get(self, request): routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)] ) -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client -def test_http_endpoint_route(): + +def test_http_endpoint_route(client): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_http_endpoint_route_path_params(): +def test_http_endpoint_route_path_params(client): response = client.get("/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_http_endpoint_route_method(): +def test_http_endpoint_route_method(client): response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_websocket_endpoint_on_connect(): +def test_websocket_endpoint_on_connect(test_client_factory): class WebSocketApp(WebSocketEndpoint): async def on_connect(self, websocket): assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_endpoint_on_receive_bytes(): +def test_websocket_endpoint_on_receive_bytes(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "bytes" async def on_receive(self, websocket, data): await websocket.send_bytes(b"Message bytes was: " + data) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_bytes(b"Hello, world!") _bytes = websocket.receive_bytes() @@ -68,14 +71,14 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json(): +def test_websocket_endpoint_on_receive_json(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket, data): await websocket.send_json({"message": data}) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() @@ -86,28 +89,28 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json_binary(): +def test_websocket_endpoint_on_receive_json_binary(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket, data): await websocket.send_json({"message": data}, mode="binary") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"message": {"hello": "world"}} -def test_websocket_endpoint_on_receive_text(): +def test_websocket_endpoint_on_receive_text(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "text" async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() @@ -118,26 +121,26 @@ async def on_receive(self, websocket, data): websocket.send_bytes(b"Hello world") -def test_websocket_endpoint_on_default(): +def test_websocket_endpoint_on_default(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = None async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() assert _text == "Message text was: Hello, world!" -def test_websocket_endpoint_on_disconnect(): +def test_websocket_endpoint_on_disconnect(test_client_factory): class WebSocketApp(WebSocketEndpoint): async def on_disconnect(self, websocket, close_code): assert close_code == 1001 await websocket.close(code=close_code) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.close(code=1001) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index f66c6b19c..5fba9981b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,7 +3,6 @@ from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute -from starlette.testclient import TestClient def raise_runtime_error(request): @@ -37,27 +36,33 @@ async def __call__(self, scope, receive, send): app = ExceptionMiddleware(router) -client = TestClient(app) -def test_not_acceptable(): +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client + + +def test_not_acceptable(client): response = client.get("/not_acceptable") assert response.status_code == 406 assert response.text == "Not Acceptable" -def test_not_modified(): +def test_not_modified(client): response = client.get("/not_modified") assert response.status_code == 304 assert response.text == "" -def test_websockets_should_raise(): +def test_websockets_should_raise(client): with pytest.raises(RuntimeError): - client.websocket_connect("/runtime_error") + with client.websocket_connect("/runtime_error"): + pass # pragma: nocover -def test_handled_exc_after_response(): +def test_handled_exc_after_response(test_client_factory, client): # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. with pytest.raises(RuntimeError): @@ -65,17 +70,33 @@ def test_handled_exc_after_response(): # If `raise_server_exceptions=False` then the test client will still allow # us to see the response as it will have been seen by the client. - allow_200_client = TestClient(app, raise_server_exceptions=False) + allow_200_client = test_client_factory(app, raise_server_exceptions=False) response = allow_200_client.get("/handled_exc_after_response") assert response.status_code == 200 assert response.text == "OK" -def test_force_500_response(): +def test_force_500_response(test_client_factory): def app(scope): raise RuntimeError() - force_500_client = TestClient(app, raise_server_exceptions=False) + force_500_client = test_client_factory(app, raise_server_exceptions=False) response = force_500_client.get("/") assert response.status_code == 500 assert response.text == "" + + +def test_repr(): + assert repr(HTTPException(404)) == ( + "HTTPException(status_code=404, detail='Not Found')" + ) + assert repr(HTTPException(404, detail="Not Found: foo")) == ( + "HTTPException(status_code=404, detail='Not Found: foo')" + ) + + class CustomHTTPException(HTTPException): + pass + + assert repr(CustomHTTPException(500, detail="Something custom")) == ( + "CustomHTTPException(status_code=500, detail='Something custom')" + ) diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 8bc71905e..8a1174e1d 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,9 +1,8 @@ import os -from starlette.formparsers import UploadFile +from starlette.formparsers import UploadFile, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.testclient import TestClient class ForceMultipartDict(dict): @@ -70,18 +69,18 @@ async def app_read_body(scope, receive, send): await response(scope, receive, send) -def test_multipart_request_data(tmpdir): - client = TestClient(app) +def test_multipart_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data"} -def test_multipart_request_files(tmpdir): +def test_multipart_request_files(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) assert response.json() == { @@ -93,12 +92,12 @@ def test_multipart_request_files(tmpdir): } -def test_multipart_request_files_with_content_type(tmpdir): +def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) assert response.json() == { @@ -110,7 +109,7 @@ def test_multipart_request_files_with_content_type(tmpdir): } -def test_multipart_request_multiple_files(tmpdir): +def test_multipart_request_multiple_files(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -119,7 +118,7 @@ def test_multipart_request_multiple_files(tmpdir): with open(path2, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} @@ -138,7 +137,7 @@ def test_multipart_request_multiple_files(tmpdir): } -def test_multi_items(tmpdir): +def test_multi_items(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -147,7 +146,7 @@ def test_multi_items(tmpdir): with open(path2, "wb") as file: file.write(b"") - client = TestClient(multi_items_app) + client = test_client_factory(multi_items_app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", @@ -171,8 +170,8 @@ def test_multi_items(tmpdir): } -def test_multipart_request_mixed_files_and_data(tmpdir): - client = TestClient(app) +def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -192,7 +191,9 @@ def test_multipart_request_mixed_files_and_data(tmpdir): b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), headers={ - "Content-Type": "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" + "Content-Type": ( + "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" + ) }, ) assert response.json() == { @@ -206,39 +207,125 @@ def test_multipart_request_mixed_files_and_data(tmpdir): } -def test_urlencoded_request_data(tmpdir): - client = TestClient(app) +def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory): + client = test_client_factory(app) + response = client.post( + "/", + data=( + # file + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501 + b"Content-Type: text/plain\r\n\r\n" + b"\r\n" + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" + ), + headers={ + "Content-Type": ( + "multipart/form-data; charset=utf-8; " + "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" + ) + }, + ) + assert response.json() == { + "file": { + "filename": "文書.txt", + "content": "", + "content_type": "text/plain", + } + } + + +def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory): + client = test_client_factory(app) + response = client.post( + "/", + data=( + # file + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" + b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501 + b"Content-Type: image/jpeg\r\n\r\n" + b"\r\n" + b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" + ), + headers={ + "Content-Type": ( + "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" + ) + }, + ) + assert response.json() == { + "file": { + "filename": "画像.jpg", + "content": "", + "content_type": "image/jpeg", + } + } + + +def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): + client = test_client_factory(app) + response = client.post( + "/", + data=( + b"--20b303e711c4ab8c443184ac833ab00f\r\n" + b"Content-Disposition: form-data; " + b'name="value"\r\n\r\n' + b"Transf\xc3\xa9rer\r\n" + b"--20b303e711c4ab8c443184ac833ab00f--\r\n" + ), + headers={ + "Content-Type": ( + "multipart/form-data; charset=utf-8; " + "boundary=20b303e711c4ab8c443184ac833ab00f" + ) + }, + ) + assert response.json() == {"value": "Transférer"} + + +def test_urlencoded_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "data"}) assert response.json() == {"some": "data"} -def test_no_request_data(tmpdir): - client = TestClient(app) +def test_no_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/") assert response.json() == {} -def test_urlencoded_percent_encoding(tmpdir): - client = TestClient(app) +def test_urlencoded_percent_encoding(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "da ta"}) assert response.json() == {"some": "da ta"} -def test_urlencoded_percent_encoding_keys(tmpdir): - client = TestClient(app) +def test_urlencoded_percent_encoding_keys(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"so me": "data"}) assert response.json() == {"so me": "data"} -def test_urlencoded_multi_field_app_reads_body(tmpdir): - client = TestClient(app_read_body) +def test_urlencoded_multi_field_app_reads_body(tmpdir, test_client_factory): + client = test_client_factory(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}) assert response.json() == {"some": "data", "second": "key pair"} -def test_multipart_multi_field_app_reads_body(tmpdir): - client = TestClient(app_read_body) +def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): + client = test_client_factory(app_read_body) response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) assert response.json() == {"some": "data", "second": "key pair"} + + +def test_user_safe_decode_helper(): + result = _user_safe_decode(b"\xc4\x99\xc5\xbc\xc4\x87", "utf-8") + assert result == "ężć" + + +def test_user_safe_decode_ignores_wrong_charset(): + result = _user_safe_decode(b"abc", "latin-8") + assert result == "abc" diff --git a/tests/test_graphql.py b/tests/test_graphql.py index e0bd034a5..8492439f8 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -1,10 +1,10 @@ import graphene +import pytest from graphql.execution.executors.asyncio import AsyncioExecutor from starlette.applications import Starlette from starlette.datastructures import Headers from starlette.graphql import GraphQLApp -from starlette.testclient import TestClient class FakeAuthMiddleware: @@ -33,55 +33,59 @@ def resolve_whoami(self, info): schema = graphene.Schema(query=Query) -app = GraphQLApp(schema=schema, graphiql=True) -client = TestClient(app) -def test_graphql_get(): +@pytest.fixture +def client(test_client_factory): + app = GraphQLApp(schema=schema, graphiql=True) + return test_client_factory(app) + + +def test_graphql_get(client): response = client.get("/?query={ hello }") assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} + assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post(): +def test_graphql_post(client): response = client.post("/?query={ hello }") assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} + assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post_json(): +def test_graphql_post_json(client): response = client.post("/", json={"query": "{ hello }"}) assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} + assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post_graphql(): +def test_graphql_post_graphql(client): response = client.post( "/", data="{ hello }", headers={"content-type": "application/graphql"} ) assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} + assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post_invalid_media_type(): +def test_graphql_post_invalid_media_type(client): response = client.post("/", data="{ hello }", headers={"content-type": "dummy"}) assert response.status_code == 415 assert response.text == "Unsupported Media Type" -def test_graphql_put(): +def test_graphql_put(client): response = client.put("/", json={"query": "{ hello }"}) assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_graphql_no_query(): +def test_graphql_no_query(client): response = client.get("/") assert response.status_code == 400 assert response.text == "No GraphQL query found in the request" -def test_graphql_invalid_field(): +def test_graphql_invalid_field(client): response = client.post("/", json={"query": "{ dummy }"}) assert response.status_code == 400 assert response.json() == { @@ -95,39 +99,39 @@ def test_graphql_invalid_field(): } -def test_graphiql_get(): +def test_graphiql_get(client): response = client.get("/", headers={"accept": "text/html"}) assert response.status_code == 200 assert "" in response.text -def test_graphiql_not_found(): +def test_graphiql_not_found(test_client_factory): app = GraphQLApp(schema=schema, graphiql=False) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept": "text/html"}) assert response.status_code == 404 assert response.text == "Not Found" -def test_add_graphql_route(): +def test_add_graphql_route(test_client_factory): app = Starlette() app.add_route("/", GraphQLApp(schema=schema)) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/?query={ hello }") assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} + assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_context(): +def test_graphql_context(test_client_factory): app = Starlette() app.add_middleware(FakeAuthMiddleware) app.add_route("/", GraphQLApp(schema=schema)) - client = TestClient(app) + client = test_client_factory(app) response = client.post( "/", json={"query": "{ whoami }"}, headers={"Authorization": "Bearer 123"} ) assert response.status_code == 200 - assert response.json() == {"data": {"whoami": "Jane"}, "errors": None} + assert response.json() == {"data": {"whoami": "Jane"}} class ASyncQuery(graphene.ObjectType): @@ -141,20 +145,8 @@ async def resolve_hello(self, info, name): async_app = GraphQLApp(schema=async_schema, executor_class=AsyncioExecutor) -def test_graphql_async(): - client = TestClient(async_app) - response = client.get("/?query={ hello }") - assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} - - -async_schema = graphene.Schema(query=ASyncQuery) -old_style_async_app = GraphQLApp(schema=async_schema, executor=AsyncioExecutor()) - - -def test_graphql_async_old_style_executor(): - # See https://github.com/encode/starlette/issues/242 - client = TestClient(old_style_async_app) +def test_graphql_async(no_trio_support, test_client_factory): + client = test_client_factory(async_app) response = client.get("/?query={ hello }") assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} + assert response.json() == {"data": {"hello": "Hello stranger"}} diff --git a/tests/test_requests.py b/tests/test_requests.py index 270181eff..d7c69fbeb 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,20 +1,18 @@ -import asyncio - +import anyio import pytest -from starlette.requests import ClientDisconnect, Request +from starlette.requests import ClientDisconnect, Request, State from starlette.responses import JSONResponse, Response -from starlette.testclient import TestClient -def test_request_url(): +def test_request_url(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) data = {"method": request.method, "url": str(request.url)} response = JSONResponse(data) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} @@ -22,26 +20,26 @@ async def app(scope, receive, send): assert response.json() == {"method": "GET", "url": "https://example.org:123/"} -def test_request_query_params(): +def test_request_query_params(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) params = dict(request.query_params) response = JSONResponse({"params": params}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/?a=123&b=456") assert response.json() == {"params": {"a": "123", "b": "456"}} -def test_request_headers(): +def test_request_headers(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) headers = dict(request.headers) response = JSONResponse({"headers": headers}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"host": "example.org"}) assert response.json() == { "headers": { @@ -54,7 +52,7 @@ async def app(scope, receive, send): } -def test_request_client(): +def test_request_client(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = JSONResponse( @@ -62,19 +60,19 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"host": "testclient", "port": 50000} -def test_request_body(): +def test_request_body(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} @@ -86,7 +84,7 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc"} -def test_request_stream(): +def test_request_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = b"" @@ -95,7 +93,7 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} @@ -107,20 +105,20 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc"} -def test_request_form_urlencoded(): +def test_request_form_urlencoded(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) form = await request.form() response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} -def test_request_body_then_stream(): +def test_request_body_then_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() @@ -130,13 +128,13 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data="abc") assert response.json() == {"body": "abc", "stream": "abc"} -def test_request_stream_then_body(): +def test_request_stream_then_body(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) chunks = b"" @@ -149,20 +147,20 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data="abc") assert response.json() == {"body": "", "stream": "abc"} -def test_request_json(): +def test_request_json(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) data = await request.json() response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": {"a": "123"}} @@ -178,7 +176,7 @@ def test_request_scope_interface(): assert len(request) == 3 -def test_request_without_setting_receive(): +def test_request_without_setting_receive(test_client_factory): """ If Request is instantiated without the receive channel, then .body() is not available. @@ -193,12 +191,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": "Receive channel not available"} -def test_request_disconnect(): +def test_request_disconnect(anyio_backend_name, anyio_backend_options): """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. @@ -212,12 +210,18 @@ async def receiver(): return {"type": "http.disconnect"} scope = {"type": "http", "method": "POST", "path": "/"} - loop = asyncio.get_event_loop() with pytest.raises(ClientDisconnect): - loop.run_until_complete(app(scope, receiver, None)) + anyio.run( + app, + scope, + receiver, + None, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) -def test_request_is_disconnected(): +def test_request_is_disconnected(test_client_factory): """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. @@ -234,25 +238,39 @@ async def app(scope, receive, send): await response(scope, receive, send) disconnected_after_response = await request.is_disconnected() - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"disconnected": False} assert disconnected_after_response -def test_request_state(): +def test_request_state_object(): + scope = {"state": {"old": "foo"}} + + s = State(scope["state"]) + + s.new = "value" + assert s.new == "value" + + del s.new + + with pytest.raises(AttributeError): + s.new + + +def test_request_state(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) request.state.example = 123 - response = JSONResponse({"state.example": request["state"].example}) + response = JSONResponse({"state.example": request.state.example}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"state.example": 123} -def test_request_cookies(): +def test_request_cookies(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) mycookie = request.cookies.get("mycookie") @@ -264,21 +282,129 @@ async def app(scope, receive, send): await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" response = client.get("/") assert response.text == "Hello, cookies!" -def test_chunked_encoding(): +def test_cookie_lenient_parsing(test_client_factory): + """ + The following test is based on a cookie set by Okta, a well-known authorization + service. It turns out that it's common practice to set cookies that would be + invalid according to the spec. + """ + tough_cookie = ( + "provider-oauth-nonce=validAsciiblabla; " + 'okta-oauth-redirect-params={"responseType":"code","state":"somestate",' + '"nonce":"somenonce","scopes":["openid","profile","email","phone"],' + '"urls":{"issuer":"https://subdomain.okta.com/oauth2/authServer",' + '"authorizeUrl":"https://subdomain.okta.com/oauth2/authServer/v1/authorize",' + '"userinfoUrl":"https://subdomain.okta.com/oauth2/authServer/v1/userinfo"}}; ' + "importantCookie=importantValue; sessionCookie=importantSessionValue" + ) + expected_keys = { + "importantCookie", + "okta-oauth-redirect-params", + "provider-oauth-nonce", + "sessionCookie", + } + + async def app(scope, receive, send): + request = Request(scope, receive) + response = JSONResponse({"cookies": request.cookies}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/", headers={"cookie": tough_cookie}) + result = response.json() + assert len(result["cookies"]) == 4 + assert set(result["cookies"].keys()) == expected_keys + + +# These test cases copied from Tornado's implementation +@pytest.mark.parametrize( + "set_cookie,expected", + [ + ("chips=ahoy; vienna=finger", {"chips": "ahoy", "vienna": "finger"}), + # all semicolons are delimiters, even within quotes + ( + 'keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"', + {"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'}, + ), + # Illegal cookies that have an '=' char in an unquoted value. + ("keebler=E=mc2", {"keebler": "E=mc2"}), + # Cookies with ':' character in their name. + ("key:term=value:term", {"key:term": "value:term"}), + # Cookies with '[' and ']'. + ("a=b; c=[; d=r; f=h", {"a": "b", "c": "[", "d": "r", "f": "h"}), + # Cookies that RFC6265 allows. + ("a=b; Domain=example.com", {"a": "b", "Domain": "example.com"}), + # parse_cookie() keeps only the last cookie with the same name. + ("a=b; h=i; a=c", {"a": "c", "h": "i"}), + ], +) +def test_cookies_edge_cases(set_cookie, expected, test_client_factory): + async def app(scope, receive, send): + request = Request(scope, receive) + response = JSONResponse({"cookies": request.cookies}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/", headers={"cookie": set_cookie}) + result = response.json() + assert result["cookies"] == expected + + +@pytest.mark.parametrize( + "set_cookie,expected", + [ + # Chunks without an equals sign appear as unnamed values per + # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 + ( + "abc=def; unnamed; django_language=en", + {"": "unnamed", "abc": "def", "django_language": "en"}, + ), + # Even a double quote may be an unamed value. + ('a=b; "; c=d', {"a": "b", "": '"', "c": "d"}), + # Spaces in names and values, and an equals sign in values. + ("a b c=d e = f; gh=i", {"a b c": "d e = f", "gh": "i"}), + # More characters the spec forbids. + ('a b,c<>@:/[]?{}=d " =e,f g', {"a b,c<>@:/[]?{}": 'd " =e,f g'}), + # Unicode characters. The spec only allows ASCII. + # ("saint=André Bessette", {"saint": "André Bessette"}), + # Browsers don't send extra whitespace or semicolons in Cookie headers, + # but cookie_parser() should parse whitespace the same way + # document.cookie parses whitespace. + # (" = b ; ; = ; c = ; ", {"": "b", "c": ""}), + ], +) +def test_cookies_invalid(set_cookie, expected, test_client_factory): + """ + Cookie strings that are against the RFC6265 spec but which browsers will send if set + via document.cookie. + """ + + async def app(scope, receive, send): + request = Request(scope, receive) + response = JSONResponse({"cookies": request.cookies}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/", headers={"cookie": set_cookie}) + result = response.json() + assert result["cookies"] == expected + + +def test_chunked_encoding(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) def post_body(): yield b"foo" @@ -286,3 +412,61 @@ def post_body(): response = client.post("/", data=post_body()) assert response.json() == {"body": "foobar"} + + +def test_request_send_push_promise(test_client_factory): + async def app(scope, receive, send): + # the server is push-enabled + scope["extensions"]["http.response.push"] = {} + + request = Request(scope, receive, send) + await request.send_push_promise("/style.css") + + response = JSONResponse({"json": "OK"}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.json() == {"json": "OK"} + + +def test_request_send_push_promise_without_push_extension(test_client_factory): + """ + If server does not support the `http.response.push` extension, + .send_push_promise() does nothing. + """ + + async def app(scope, receive, send): + request = Request(scope) + await request.send_push_promise("/style.css") + + response = JSONResponse({"json": "OK"}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.json() == {"json": "OK"} + + +def test_request_send_push_promise_without_setting_send(test_client_factory): + """ + If Request is instantiated without the send channel, then + .send_push_promise() is not available. + """ + + async def app(scope, receive, send): + # the server is push-enabled + scope["extensions"]["http.response.push"] = {} + + data = "OK" + request = Request(scope) + try: + await request.send_push_promise("/style.css") + except RuntimeError: + data = "Send channel not available" + response = JSONResponse({"json": data}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.json() == {"json": "Send channel not available"} diff --git a/tests/test_responses.py b/tests/test_responses.py index 3d5de413f..baba549ba 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,6 +1,6 @@ -import asyncio import os +import anyio import pytest from starlette import status @@ -8,45 +8,44 @@ from starlette.requests import Request from starlette.responses import ( FileResponse, + JSONResponse, RedirectResponse, Response, StreamingResponse, - UJSONResponse, ) -from starlette.testclient import TestClient -def test_text_response(): +def test_text_response(test_client_factory): async def app(scope, receive, send): response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "hello, world" -def test_bytes_response(): +def test_bytes_response(test_client_factory): async def app(scope, receive, send): response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.content == b"xxxxx" -def test_ujson_response(): +def test_json_none_response(test_client_factory): async def app(scope, receive, send): - response = UJSONResponse({"hello": "world"}) + response = JSONResponse(None) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") - assert response.json() == {"hello": "world"} + assert response.json() is None -def test_redirect_response(): +def test_redirect_response(test_client_factory): async def app(scope, receive, send): if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") @@ -54,13 +53,27 @@ async def app(scope, receive, send): response = RedirectResponse("/") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/" -def test_streaming_response(): +def test_quoting_redirect_response(test_client_factory): + async def app(scope, receive, send): + if scope["path"] == "/I ♥ Starlette/": + response = Response("hello, world", media_type="text/plain") + else: + response = RedirectResponse("/I ♥ Starlette/") + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/redirect") + assert response.text == "hello, world" + assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" + + +def test_streaming_response(test_client_factory): filled_by_bg_task = "" async def app(scope, receive, send): @@ -69,7 +82,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task @@ -84,13 +97,51 @@ async def numbers_for_cleanup(start=1, stop=5): await response(scope, receive, send) assert filled_by_bg_task == "" - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" -def test_sync_streaming_response(): +def test_streaming_response_custom_iterator(test_client_factory): + async def app(scope, receive, send): + class CustomAsyncIterator: + def __init__(self): + self._called = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._called == 5: + raise StopAsyncIteration() + self._called += 1 + return str(self._called) + + response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "12345" + + +def test_streaming_response_custom_iterable(test_client_factory): + async def app(scope, receive, send): + class CustomAsyncIterable: + async def __aiter__(self): + for i in range(5): + yield str(i + 1) + + response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "12345" + + +def test_sync_streaming_response(test_client_factory): async def app(scope, receive, send): def numbers(minimum, maximum): for i in range(minimum, maximum + 1): @@ -102,37 +153,37 @@ def numbers(minimum, maximum): response = StreamingResponse(generator, media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" -def test_response_headers(): +def test_response_headers(test_client_factory): async def app(scope, receive, send): headers = {"x-header-1": "123", "x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.headers["x-header-1"] == "123" assert response.headers["x-header-2"] == "789" -def test_response_phrase(): +def test_response_phrase(test_client_factory): app = Response(status_code=204) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.reason == "No Content" app = Response(b"", status_code=123) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.reason == "" -def test_file_response(tmpdir): +def test_file_response(tmpdir, test_client_factory): path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -145,7 +196,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task @@ -161,7 +212,7 @@ async def app(scope, receive, send): await response(scope, receive, send) assert filled_by_bg_task == "" - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK @@ -174,24 +225,39 @@ async def app(scope, receive, send): assert filled_by_bg_task == "6, 7, 8, 9" -def test_file_response_with_directory_raises_error(tmpdir): +def test_file_response_with_directory_raises_error(tmpdir, test_client_factory): app = FileResponse(path=tmpdir, filename="example.png") - client = TestClient(app) - with pytest.raises(RuntimeError) as exc: + client = test_client_factory(app) + with pytest.raises(RuntimeError) as exc_info: client.get("/") - assert "is not a file" in str(exc) + assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmpdir): +def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.txt") app = FileResponse(path=path, filename="404.txt") - client = TestClient(app) - with pytest.raises(RuntimeError) as exc: + client = test_client_factory(app) + with pytest.raises(RuntimeError) as exc_info: client.get("/") - assert "does not exist" in str(exc) + assert "does not exist" in str(exc_info.value) + + +def test_file_response_with_chinese_filename(tmpdir, test_client_factory): + content = b"file content" + filename = "你好.txt" # probably "Hello.txt" in Chinese + path = os.path.join(tmpdir, filename) + with open(path, "wb") as f: + f.write(content) + app = FileResponse(path=path, filename=filename) + client = test_client_factory(app) + response = client.get("/") + expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt" + assert response.status_code == status.HTTP_200_OK + assert response.content == content + assert response.headers["content-disposition"] == expected_disposition -def test_set_cookie(): +def test_set_cookie(test_client_factory): async def app(scope, receive, send): response = Response("Hello, world!", media_type="text/plain") response.set_cookie( @@ -203,15 +269,16 @@ async def app(scope, receive, send): domain="localhost", secure=True, httponly=True, + samesite="none", ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_delete_cookie(): +def test_delete_cookie(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") @@ -221,24 +288,24 @@ async def app(scope, receive, send): response.set_cookie("mycookie", "myvalue") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.cookies["mycookie"] response = client.get("/") assert not response.cookies.get("mycookie") -def test_populate_headers(): +def test_populate_headers(test_client_factory): app = Response(content="hi", headers={}, media_type="text/html") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_head_method(): +def test_head_method(test_client_factory): app = Response("hello, world", media_type="text/plain") - client = TestClient(app) + client = test_client_factory(app) response = client.head("/") assert response.text == "" diff --git a/tests/test_routing.py b/tests/test_routing.py index 3c0c4cd9c..9e734b9cc 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,8 +1,11 @@ +import functools +import uuid + import pytest +from starlette.applications import Starlette from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -74,6 +77,19 @@ def path_convertor(request): return JSONResponse({"path": path}) +@app.route("/uuid/{param:uuid}", name="uuid-convertor") +def uuid_converter(request): + uuid_param = request.path_params["param"] + return JSONResponse({"uuid": str(uuid_param)}) + + +# Route with chars that conflict with regex meta chars +@app.route("/path-with-parentheses({param:int})", name="path-with-parentheses") +def path_with_parentheses(request): + number = request.path_params["param"] + return JSONResponse({"int": number}) + + @app.websocket_route("/ws") async def websocket_endpoint(session): await session.accept() @@ -88,10 +104,19 @@ async def websocket_params(session): await session.close() -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client -def test_router(): +@pytest.mark.filterwarnings( + r"ignore" + r":Trying to detect encoding from a tiny portion of \(5\) byte\(s\)\." + r":UserWarning" + r":charset_normalizer.api" +) +def test_router(client): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world" @@ -116,6 +141,11 @@ def test_router(): assert response.status_code == 200 assert response.text == "User fixed me" + response = client.get("/users/tomchristie/") + assert response.status_code == 200 + assert response.url == "http://testserver/users/tomchristie" + assert response.text == "User tomchristie" + response = client.get("/users/nomatch") assert response.status_code == 200 assert response.text == "User nomatch" @@ -125,13 +155,22 @@ def test_router(): assert response.text == "xxxxx" -def test_route_converters(): +def test_route_converters(client): # Test integer conversion response = client.get("/int/5") assert response.status_code == 200 assert response.json() == {"int": 5} assert app.url_path_for("int-convertor", param=5) == "/int/5" + # Test path with parentheses + response = client.get("/path-with-parentheses(7)") + assert response.status_code == 200 + assert response.json() == {"int": 7} + assert ( + app.url_path_for("path-with-parentheses", param=7) + == "/path-with-parentheses(7)" + ) + # Test float conversion response = client.get("/float/25.5") assert response.status_code == 200 @@ -146,6 +185,17 @@ def test_route_converters(): app.url_path_for("path-convertor", param="some/example") == "/path/some/example" ) + # Test UUID conversion + response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a") + assert response.status_code == 200 + assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"} + assert ( + app.url_path_for( + "uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a") + ) + == "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a" + ) + def test_url_path_for(): assert app.url_path_for("homepage") == "/" @@ -164,12 +214,24 @@ def test_url_for(): app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/" ) + assert ( + app.url_path_for("homepage").make_absolute_url( + base_url="https://example.org/root_path/" + ) + == "https://example.org/root_path/" + ) assert ( app.url_path_for("user", username="tomchristie").make_absolute_url( base_url="https://example.org" ) == "https://example.org/users/tomchristie" ) + assert ( + app.url_path_for("user", username="tomchristie").make_absolute_url( + base_url="https://example.org/root_path/" + ) + == "https://example.org/root_path/users/tomchristie" + ) assert ( app.url_path_for("websocket_endpoint").make_absolute_url( base_url="https://example.org" @@ -178,19 +240,19 @@ def test_url_for(): ) -def test_router_add_route(): +def test_router_add_route(client): response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_router_duplicate_path(): +def test_router_duplicate_path(client): response = client.post("/func") assert response.status_code == 200 assert response.text == "Hello, POST!" -def test_router_add_websocket_route(): +def test_router_add_websocket_route(client): with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" @@ -221,8 +283,8 @@ async def __call__(self, scope, receive, send): ) -def test_protocol_switch(): - client = TestClient(mixed_protocol_app) +def test_protocol_switch(test_client_factory): + client = test_client_factory(mixed_protocol_app) response = client.get("/") assert response.status_code == 200 @@ -232,15 +294,16 @@ def test_protocol_switch(): assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/404") + with client.websocket_connect("/404"): + pass # pragma: nocover ok = PlainTextResponse("OK") -def test_mount_urls(): +def test_mount_urls(test_client_factory): mounted = Router([Mount("/users", ok, name="users")]) - client = TestClient(mounted) + client = test_client_factory(mounted) assert client.get("/users").status_code == 200 assert client.get("/users").url == "http://testserver/users/" assert client.get("/users/").status_code == 200 @@ -263,9 +326,9 @@ def test_reverse_mount_urls(): ) -def test_mount_at_root(): +def test_mount_at_root(test_client_factory): mounted = Router([Mount("/", ok, name="users")]) - client = TestClient(mounted) + client = test_client_factory(mounted) assert client.get("/").status_code == 200 @@ -293,8 +356,8 @@ def users_api(request): ) -def test_host_routing(): - client = TestClient(mixed_hosts_app, base_url="https://api.example.org/") +def test_host_routing(test_client_factory): + client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -303,7 +366,7 @@ def test_host_routing(): response = client.get("/") assert response.status_code == 404 - client = TestClient(mixed_hosts_app, base_url="https://www.example.org/") + client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -338,8 +401,8 @@ async def subdomain_app(scope, receive, send): ) -def test_subdomain_routing(): - client = TestClient(subdomain_app, base_url="https://foo.example.org/") +def test_subdomain_routing(test_client_factory): + client = test_client_factory(subdomain_app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -353,3 +416,235 @@ def test_subdomain_reverse_urls(): ).make_absolute_url("https://whatever") == "https://foo.example.org/homepage" ) + + +async def echo_urls(request): + return JSONResponse( + { + "index": request.url_for("index"), + "submount": request.url_for("mount:submount"), + } + ) + + +echo_url_routes = [ + Route("/", echo_urls, name="index", methods=["GET"]), + Mount( + "/submount", + name="mount", + routes=[Route("/", echo_urls, name="submount", methods=["GET"])], + ), +] + + +def test_url_for_with_root_path(test_client_factory): + app = Starlette(routes=echo_url_routes) + client = test_client_factory( + app, base_url="https://www.example.org/", root_path="/sub_path" + ) + response = client.get("/") + assert response.json() == { + "index": "https://www.example.org/sub_path/", + "submount": "https://www.example.org/sub_path/submount/", + } + response = client.get("/submount/") + assert response.json() == { + "index": "https://www.example.org/sub_path/", + "submount": "https://www.example.org/sub_path/submount/", + } + + +async def stub_app(scope, receive, send): + pass # pragma: no cover + + +double_mount_routes = [ + Mount("/mount", name="mount", routes=[Mount("/static", stub_app, name="static")]), +] + + +def test_url_for_with_double_mount(): + app = Starlette(routes=double_mount_routes) + url = app.url_path_for("mount:static", path="123") + assert url == "/mount/static/123" + + +def test_standalone_route_matches(test_client_factory): + app = Route("/", PlainTextResponse("Hello, World!")) + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Hello, World!" + + +def test_standalone_route_does_not_match(test_client_factory): + app = Route("/", PlainTextResponse("Hello, World!")) + client = test_client_factory(app) + response = client.get("/invalid") + assert response.status_code == 404 + assert response.text == "Not Found" + + +async def ws_helloworld(websocket): + await websocket.accept() + await websocket.send_text("Hello, world!") + await websocket.close() + + +def test_standalone_ws_route_matches(test_client_factory): + app = WebSocketRoute("/", ws_helloworld) + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + text = websocket.receive_text() + assert text == "Hello, world!" + + +def test_standalone_ws_route_does_not_match(test_client_factory): + app = WebSocketRoute("/", ws_helloworld) + client = test_client_factory(app) + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect("/invalid"): + pass # pragma: nocover + + +def test_lifespan_async(test_client_factory): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + return PlainTextResponse("hello, world") + + async def run_startup(): + nonlocal startup_complete + startup_complete = True + + async def run_shutdown(): + nonlocal shutdown_complete + shutdown_complete = True + + app = Router( + on_startup=[run_startup], + on_shutdown=[run_shutdown], + routes=[Route("/", hello_world)], + ) + + assert not startup_complete + assert not shutdown_complete + with test_client_factory(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + +def test_lifespan_sync(test_client_factory): + startup_complete = False + shutdown_complete = False + + def hello_world(request): + return PlainTextResponse("hello, world") + + def run_startup(): + nonlocal startup_complete + startup_complete = True + + def run_shutdown(): + nonlocal shutdown_complete + shutdown_complete = True + + app = Router( + on_startup=[run_startup], + on_shutdown=[run_shutdown], + routes=[Route("/", hello_world)], + ) + + assert not startup_complete + assert not shutdown_complete + with test_client_factory(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + +def test_raise_on_startup(test_client_factory): + def run_startup(): + raise RuntimeError() + + router = Router(on_startup=[run_startup]) + + async def app(scope, receive, send): + async def _send(message): + nonlocal startup_failed + if message["type"] == "lifespan.startup.failed": + startup_failed = True + return await send(message) + + await router(scope, receive, _send) + + startup_failed = False + with pytest.raises(RuntimeError): + with test_client_factory(app): + pass # pragma: nocover + assert startup_failed + + +def test_raise_on_shutdown(test_client_factory): + def run_shutdown(): + raise RuntimeError() + + app = Router(on_shutdown=[run_shutdown]) + + with pytest.raises(RuntimeError): + with test_client_factory(app): + pass # pragma: nocover + + +class AsyncEndpointClassMethod: + @classmethod + async def async_endpoint(cls, arg, request): + return JSONResponse({"arg": arg}) + + +async def _partial_async_endpoint(arg, request): + return JSONResponse({"arg": arg}) + + +partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") +partial_cls_async_endpoint = functools.partial( + AsyncEndpointClassMethod.async_endpoint, "foo" +) + +partial_async_app = Router( + routes=[ + Route("/", partial_async_endpoint), + Route("/cls", partial_cls_async_endpoint), + ] +) + + +def test_partial_async_endpoint(test_client_factory): + test_client = test_client_factory(partial_async_app) + response = test_client.get("/") + assert response.status_code == 200 + assert response.json() == {"arg": "foo"} + + cls_method_response = test_client.get("/cls") + assert cls_method_response.status_code == 200 + assert cls_method_response.json() == {"arg": "foo"} + + +def test_duplicated_param_names(): + with pytest.raises( + ValueError, + match="Duplicated param name id at path /{id}/{id}", + ): + Route("/{id}/{id}", user) + + with pytest.raises( + ValueError, + match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}", + ): + Route("/{id}/{name}/{id}/{name}", user) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 0ae43238f..28fe777f0 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,7 +1,6 @@ from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.schemas import SchemaGenerator -from starlette.testclient import TestClient schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} @@ -213,8 +212,8 @@ def test_schema_generation(): """ -def test_schema_endpoint(): - client = TestClient(app) +def test_schema_endpoint(test_client_factory): + client = test_client_factory(app) response = client.get("/schema") assert response.headers["Content-Type"] == "application/vnd.oai.openapi" assert response.text.strip() == EXPECTED_SCHEMA.strip() diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 142b8f654..d5ec1afc5 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,100 +1,140 @@ -import asyncio import os +import pathlib import time +import anyio import pytest +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.routing import Mount from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient -def test_staticfiles(tmpdir): +def test_staticfiles(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_with_package(): +def test_staticfiles_with_pathlib(tmpdir, test_client_factory): + base_dir = pathlib.Path(tmpdir) + path = base_dir / "example.txt" + with open(path, "w") as file: + file.write("") + + app = StaticFiles(directory=base_dir) + client = test_client_factory(app) + response = client.get("/example.txt") + assert response.status_code == 200 + assert response.text == "" + + +def test_staticfiles_head_with_middleware(tmpdir, test_client_factory): + """ + see https://github.com/encode/starlette/pull/935 + """ + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("x" * 100) + + routes = [Mount("/static", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) + + @app.middleware("http") + async def does_nothing_middleware(request: Request, call_next): + response = await call_next(request) + return response + + client = test_client_factory(app) + response = client.head("/static/example.txt") + assert response.status_code == 200 + assert response.headers.get("content-length") == "100" + + +def test_staticfiles_with_package(test_client_factory): app = StaticFiles(packages=["tests"]) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" -def test_staticfiles_post(tmpdir): +def test_staticfiles_post(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_staticfiles_with_directory_returns_404(tmpdir): +def test_staticfiles_with_directory_returns_404(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 404 assert response.text == "Not Found" -def test_staticfiles_with_missing_file_returns_404(tmpdir): +def test_staticfiles_with_missing_file_returns_404(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/404.txt") assert response.status_code == 404 assert response.text == "Not Found" def test_staticfiles_instantiated_with_missing_directory(tmpdir): - with pytest.raises(RuntimeError) as exc: + with pytest.raises(RuntimeError) as exc_info: path = os.path.join(tmpdir, "no_such_directory") StaticFiles(directory=path) - assert "does not exist" in str(exc) + assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_missing_directory(tmpdir): +def test_staticfiles_configured_with_missing_directory(tmpdir, test_client_factory): path = os.path.join(tmpdir, "no_such_directory") app = StaticFiles(directory=path, check_dir=False) - client = TestClient(app) - with pytest.raises(RuntimeError) as exc: + client = test_client_factory(app) + with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") - assert "does not exist" in str(exc) + assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_file_instead_of_directory(tmpdir): +def test_staticfiles_configured_with_file_instead_of_directory( + tmpdir, test_client_factory +): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=path, check_dir=False) - client = TestClient(app) - with pytest.raises(RuntimeError) as exc: + client = test_client_factory(app) + with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") - assert "is not a directory" in str(exc) + assert "is not a directory" in str(exc_info.value) -def test_staticfiles_config_check_occurs_only_once(tmpdir): +def test_staticfiles_config_check_occurs_only_once(tmpdir, test_client_factory): app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) assert not app.config_checked client.get("/") assert app.config_checked @@ -114,32 +154,31 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): # We can't test this with 'requests', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} - loop = asyncio.get_event_loop() - response = loop.run_until_complete(app.get_response(path, scope)) + response = anyio.run(app.get_response, path, scope) assert response.status_code == 404 assert response.body == b"Not Found" -def test_staticfiles_never_read_file_for_head_method(tmpdir): +def test_staticfiles_never_read_file_for_head_method(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.head("/example.txt") assert response.status_code == 200 assert response.content == b"" assert response.headers["content-length"] == "14" -def test_staticfiles_304_with_etag_match(tmpdir): +def test_staticfiles_304_with_etag_match(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 last_etag = first_resp.headers["etag"] @@ -148,7 +187,9 @@ def test_staticfiles_304_with_etag_match(tmpdir): assert second_resp.content == b"" -def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): +def test_staticfiles_304_with_last_modified_compare_last_req( + tmpdir, test_client_factory +): path = os.path.join(tmpdir, "example.txt") file_last_modified_time = time.mktime( time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S") @@ -158,7 +199,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): os.utime(path, (file_last_modified_time, file_last_modified_time)) app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) # last modified less than last request, 304 response = client.get( "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"} @@ -173,7 +214,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): assert response.content == b"" -def test_staticfiles_html(tmpdir): +def test_staticfiles_html(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -184,7 +225,7 @@ def test_staticfiles_html(tmpdir): file.write("

Hello

") app = StaticFiles(directory=tmpdir, html=True) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" @@ -204,3 +245,42 @@ def test_staticfiles_html(tmpdir): response = client.get("/missing") assert response.status_code == 404 assert response.text == "

Custom not found page

" + + +def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( + tmpdir, test_client_factory +): + path_404 = os.path.join(tmpdir, "404.html") + with open(path_404, "w") as file: + file.write("

404 file

") + path_some = os.path.join(tmpdir, "some.html") + with open(path_some, "w") as file: + file.write("

some file

") + + common_modified_time = time.mktime( + time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S") + ) + os.utime(path_404, (common_modified_time, common_modified_time)) + os.utime(path_some, (common_modified_time, common_modified_time)) + + app = StaticFiles(directory=tmpdir, html=True) + client = test_client_factory(app) + + resp_exists = client.get("/some.html") + assert resp_exists.status_code == 200 + assert resp_exists.text == "

some file

" + + resp_cached = client.get( + "/some.html", + headers={"If-Modified-Since": resp_exists.headers["last-modified"]}, + ) + assert resp_cached.status_code == 304 + + os.remove(path_some) + + resp_deleted = client.get( + "/some.html", + headers={"If-Modified-Since": resp_exists.headers["last-modified"]}, + ) + assert resp_deleted.status_code == 404 + assert resp_deleted.text == "

404 file

" diff --git a/tests/test_templates.py b/tests/test_templates.py index a0ab3e1b0..073482d65 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -4,10 +4,9 @@ from starlette.applications import Starlette from starlette.templating import Jinja2Templates -from starlette.testclient import TestClient -def test_templates(tmpdir): +def test_templates(tmpdir, test_client_factory): path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -19,7 +18,7 @@ def test_templates(tmpdir): async def homepage(request): return templates.TemplateResponse("index.html", {"request": request}) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 6569990fb..8c0666789 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,12 +1,24 @@ import asyncio +import itertools +import sys +import anyio import pytest +import sniffio +import trio.lowlevel from starlette.applications import Starlette +from starlette.middleware import Middleware from starlette.responses import JSONResponse -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect +if sys.version_info >= (3, 7): # pragma: no cover + from asyncio import current_task as asyncio_current_task + from contextlib import asynccontextmanager +else: # pragma: no cover + asyncio_current_task = asyncio.Task.current_task + from contextlib2 import asynccontextmanager + mock_service = Starlette() @@ -15,14 +27,19 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) -app = Starlette() - +def current_task(): + # anyio's TaskInfo comparisons are invalid after their associated native + # task object is GC'd https://github.com/agronholm/anyio/issues/324 + asynclib_name = sniffio.current_async_library() + if asynclib_name == "trio": + return trio.lowlevel.current_task() -@app.route("/") -def homepage(request): - client = TestClient(mock_service) - response = client.get("/") - return JSONResponse(response.json()) + if asynclib_name == "asyncio": + task = asyncio_current_task() + if task is None: + raise RuntimeError("must be called from a running task") # pragma: no cover + return task + raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover startup_error_app = Starlette() @@ -33,30 +50,128 @@ def startup(): raise RuntimeError() -def test_use_testclient_in_endpoint(): +def test_use_testclient_in_endpoint(test_client_factory): """ We should be able to use the test client within applications. This is useful if we need to mock out other services, during tests or in development. """ - client = TestClient(app) + + app = Starlette() + + @app.route("/") + def homepage(request): + client = test_client_factory(mock_service) + response = client.get("/") + return JSONResponse(response.json()) + + client = test_client_factory(app) response = client.get("/") assert response.json() == {"mock": "example"} -def testclient_as_contextmanager(): - with TestClient(app): - pass +def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): + """ + This test asserts a number of properties that are important for an + app level task_group + """ + counter = itertools.count() + identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar") + + def get_identity(): + try: + return identity_runvar.get() + except LookupError: + token = next(counter) + identity_runvar.set(token) + return token + + startup_task = object() + startup_loop = None + shutdown_task = object() + shutdown_loop = None + + @asynccontextmanager + async def lifespan_context(app): + nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop + + startup_task = current_task() + startup_loop = get_identity() + async with anyio.create_task_group() as app.task_group: + yield + shutdown_task = current_task() + shutdown_loop = get_identity() + + app = Starlette(lifespan=lifespan_context) + + @app.route("/loop_id") + async def loop_id(request): + return JSONResponse(get_identity()) + + client = test_client_factory(app) + + with client: + # within a TestClient context every async request runs in the same thread + assert client.get("/loop_id").json() == 0 + assert client.get("/loop_id").json() == 0 + # that thread is also the same as the lifespan thread + assert startup_loop == 0 + assert shutdown_loop == 0 -def test_error_on_startup(): + # lifespan events run in the same task, this is important because a task + # group must be entered and exited in the same task. + assert startup_task is shutdown_task + + # outside the TestClient context, new requests continue to spawn in new + # eventloops in new threads + assert client.get("/loop_id").json() == 1 + assert client.get("/loop_id").json() == 2 + + first_task = startup_task + + with client: + # the TestClient context can be re-used, starting a new lifespan task + # in a new thread + assert client.get("/loop_id").json() == 3 + assert client.get("/loop_id").json() == 3 + + assert startup_loop == 3 + assert shutdown_loop == 3 + + # lifespan events still run in the same task, with the context but... + assert startup_task is shutdown_task + + # ... the second TestClient context creates a new lifespan task. + assert first_task is not startup_task + + +def test_error_on_startup(test_client_factory): with pytest.raises(RuntimeError): - with TestClient(startup_error_app): + with test_client_factory(startup_error_app): + pass # pragma: no cover + + +def test_exception_in_middleware(test_client_factory): + class MiddlewareException(Exception): + pass + + class BrokenMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + raise MiddlewareException() + + broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) + + with pytest.raises(MiddlewareException): + with test_client_factory(broken_middleware): pass # pragma: no cover -def test_testclient_asgi2(): +def test_testclient_asgi2(test_client_factory): def app(scope): async def inner(receive, send): await send( @@ -70,12 +185,12 @@ async def inner(receive, send): return inner - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_testclient_asgi3(): +def test_testclient_asgi3(test_client_factory): async def app(scope, receive, send): await send( { @@ -86,12 +201,12 @@ async def app(scope, receive, send): ) await send({"type": "http.response.body", "body": b"Hello, world!"}) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_websocket_blocking_receive(): +def test_websocket_blocking_receive(test_client_factory): def app(scope): async def respond(websocket): await websocket.send_json({"message": "test"}) @@ -99,17 +214,18 @@ async def respond(websocket): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() - asyncio.ensure_future(respond(websocket)) - try: - # this will block as the client does not send us data - # it should not prevent `respond` from executing though - await websocket.receive_json() - except WebSocketDisconnect: - pass + async with anyio.create_task_group() as task_group: + task_group.start_soon(respond, websocket) + try: + # this will block as the client does not send us data + # it should not prevent `respond` from executing though + await websocket.receive_json() + except WebSocketDisconnect: + pass return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 41e77237c..e02d433d5 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,11 +1,11 @@ +import anyio import pytest from starlette import status -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect -def test_websocket_url(): +def test_websocket_url(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -15,13 +15,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/123?a=abc"} -def test_websocket_binary_json(): +def test_websocket_binary_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -32,14 +32,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "data"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"test": "data"} -def test_websocket_query_params(): +def test_websocket_query_params(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -50,13 +50,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/?a=abc&b=456") as websocket: data = websocket.receive_json() assert data == {"params": {"a": "abc", "b": "456"}} -def test_websocket_headers(): +def test_websocket_headers(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -67,7 +67,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: expected_headers = { "accept": "*/*", @@ -82,7 +82,7 @@ async def asgi(receive, send): assert data == {"headers": expected_headers} -def test_websocket_port(): +def test_websocket_port(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -92,13 +92,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"port": 123} -def test_websocket_send_and_receive_text(): +def test_websocket_send_and_receive_text(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -109,14 +109,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_send_and_receive_bytes(): +def test_websocket_send_and_receive_bytes(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -127,14 +127,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_send_and_receive_json(): +def test_websocket_send_and_receive_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -145,14 +145,96 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_client_close(): +def test_websocket_iter_text(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + async for data in websocket.iter_text(): + await websocket.send_text("Message was: " + data) + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + websocket.send_text("Hello, world!") + data = websocket.receive_text() + assert data == "Message was: Hello, world!" + + +def test_websocket_iter_bytes(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + async for data in websocket.iter_bytes(): + await websocket.send_bytes(b"Message was: " + data) + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + websocket.send_bytes(b"Hello, world!") + data = websocket.receive_bytes() + assert data == b"Message was: Hello, world!" + + +def test_websocket_iter_json(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + async for data in websocket.iter_json(): + await websocket.send_json({"message": data}) + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + websocket.send_json({"hello": "world"}) + data = websocket.receive_json() + assert data == {"message": {"hello": "world"}} + + +def test_websocket_concurrency_pattern(test_client_factory): + def app(scope): + stream_send, stream_receive = anyio.create_memory_object_stream() + + async def reader(websocket): + async with stream_send: + async for data in websocket.iter_json(): + await stream_send.send(data) + + async def writer(websocket): + async with stream_receive: + async for message in stream_receive: + await websocket.send_json(message) + + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + async with anyio.create_task_group() as task_group: + task_group.start_soon(reader, websocket) + await writer(websocket) + await websocket.close() + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + websocket.send_json({"hello": "world"}) + data = websocket.receive_json() + assert data == {"hello": "world"} + + +def test_client_close(test_client_factory): close_code = None def app(scope): @@ -167,13 +249,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.close(code=status.WS_1001_GOING_AWAY) assert close_code == status.WS_1001_GOING_AWAY -def test_application_close(): +def test_application_close(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -182,14 +264,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: with pytest.raises(WebSocketDisconnect) as exc: websocket.receive_text() assert exc.value.code == status.WS_1001_GOING_AWAY -def test_rejected_connection(): +def test_rejected_connection(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -197,13 +279,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(WebSocketDisconnect) as exc: - client.websocket_connect("/") + with client.websocket_connect("/"): + pass # pragma: nocover assert exc.value.code == status.WS_1001_GOING_AWAY -def test_subprotocol(): +def test_subprotocol(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -213,24 +296,25 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_exception(): +def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send): assert False return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(AssertionError): - client.websocket_connect("/123?a=abc") + with client.websocket_connect("/123?a=abc"): + pass # pragma: nocover -def test_duplicate_close(): +def test_duplicate_close(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -240,13 +324,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass + pass # pragma: nocover -def test_duplicate_disconnect(): +def test_duplicate_disconnect(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -257,7 +341,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/") as websocket: websocket.close() @@ -283,3 +367,13 @@ async def mock_send(message): assert websocket["type"] == "websocket" assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []} assert len(websocket) == 3 + + # check __eq__ and __hash__ + assert websocket != WebSocket( + {"type": "websocket", "path": "/abc/", "headers": []}, + receive=mock_receive, + send=mock_send, + ) + assert websocket == websocket + assert websocket in {websocket} + assert {websocket} == {websocket} From db9bf2c7cf7ff4fadd6bccb1bacb4245b968d6fb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 11 Aug 2021 21:02:08 +0200 Subject: [PATCH 05/14] Add WebSocketException and support for WS handlers --- starlette/exceptions.py | 19 +++++-------------- starlette/middleware/errors.py | 2 +- tests/test_applications.py | 12 ++++++------ tests/test_exceptions.py | 13 +++++++++++-- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 557534d87..d32807397 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -2,7 +2,6 @@ import http import typing -from starlette import status from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import PlainTextResponse, Response @@ -23,21 +22,13 @@ def __repr__(self) -> str: class WebSocketException(Exception): - def __init__(self, code: int = status.WS_1008_POLICY_VIOLATION) -> None: - """ - `code` defaults to 1008, from the WebSocket specification: - - > 1008 indicates that an endpoint is terminating the connection - > because it has received a message that violates its policy. This - > is a generic status code that can be returned when there is no - > other more suitable status code (e.g., 1003 or 1009) or if there - > is a need to hide specific details about the policy. - - Set `code` to any value allowed by the - [WebSocket specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). - """ + def __init__(self, code: int) -> None: self.code = code + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(code={self.code!r})" + class ExceptionMiddleware: def __init__( diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 0eaae03ad..041e3468d 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -160,7 +160,7 @@ async def _send(message: Message) -> None: try: await self.app(scope, receive, _send) except Exception as exc: - if not response_started and scope["type"] == "http": + if scope["type"] == "http" and not response_started: request = Request(scope) if self.debug: # In debug mode, return traceback responses. diff --git a/tests/test_applications.py b/tests/test_applications.py index aaccabab0..c7c3e9d57 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,7 +1,7 @@ -import asyncio import os import sys +import anyio import pytest from starlette import status @@ -12,6 +12,7 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +from starlette.websockets import WebSocket if sys.version_info >= (3, 7): from contextlib import asynccontextmanager # pragma: no cover @@ -96,7 +97,7 @@ async def websocket_endpoint(session): @app.websocket_route("/ws-raise-websocket") -async def websocket_raise_websocket_exception(websocket): +async def websocket_raise_websocket_exception(websocket: WebSocket): await websocket.accept() raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) @@ -106,15 +107,14 @@ class CustomWSException(Exception): @app.websocket_route("/ws-raise-custom") -async def websocket_raise_custom(websocket): +async def websocket_raise_custom(websocket: WebSocket): await websocket.accept() raise CustomWSException() @app.exception_handler(CustomWSException) -def custom_ws_exception_handler(websocket, exc): - loop = asyncio.new_event_loop() - loop.run_until_complete(websocket.close(code=status.WS_1013_TRY_AGAIN_LATER)) +def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException): + anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER) @pytest.fixture diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 5fba9981b..e6c2ad72d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,6 +1,6 @@ import pytest -from starlette.exceptions import ExceptionMiddleware, HTTPException +from starlette.exceptions import ExceptionMiddleware, HTTPException, WebSocketException from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute @@ -86,7 +86,7 @@ def app(scope): assert response.text == "" -def test_repr(): +def test_http_repr(): assert repr(HTTPException(404)) == ( "HTTPException(status_code=404, detail='Not Found')" ) @@ -100,3 +100,12 @@ class CustomHTTPException(HTTPException): assert repr(CustomHTTPException(500, detail="Something custom")) == ( "CustomHTTPException(status_code=500, detail='Something custom')" ) + + +def test_websocket_repr(): + assert repr(WebSocketException(1008)) == ("WebSocketException(code=1008)") + + class CustomWebSocketException(WebSocketException): + pass + + assert repr(CustomWebSocketException(1013)) == "CustomWebSocketException(code=1013)" From bdd0d2f922eb8904b0fcc7064fe1ed76769849f3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 27 Jan 2022 11:40:02 -0600 Subject: [PATCH 06/14] incorporate reason from #1417 --- tests/test_applications.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_applications.py b/tests/test_applications.py index c7c3e9d57..10e4c02af 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -204,6 +204,7 @@ def test_websocket_raise_websocket_exception(client): assert response == { "type": "websocket.close", "code": status.WS_1003_UNSUPPORTED_DATA, + "reason": "", } @@ -213,6 +214,7 @@ def test_websocket_raise_custom_exception(client): assert response == { "type": "websocket.close", "code": status.WS_1013_TRY_AGAIN_LATER, + "reason": "", } From 41b06d2378243ba7fee5ce28207d5ea316c2d6ac Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 10 May 2022 07:44:10 +0200 Subject: [PATCH 07/14] Add reason to WebSocketException --- starlette/exceptions.py | 7 ++++--- tests/test_exceptions.py | 9 +++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 9eeab7933..6b173391a 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -28,12 +28,13 @@ def __repr__(self) -> str: class WebSocketException(Exception): - def __init__(self, code: int) -> None: + def __init__(self, code: int, reason: typing.Optional[str] = None) -> None: self.code = code + self.reason = reason or "" def __repr__(self) -> str: class_name = self.__class__.__name__ - return f"{class_name}(code={self.code!r})" + return f"{class_name}(code={self.code!r}, reason={self.reason!r})" class ExceptionMiddleware: @@ -133,4 +134,4 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: async def websocket_exception( self, websocket: WebSocket, exc: WebSocketException ) -> None: - await websocket.close(code=exc.code) + await websocket.close(code=exc.code, reason=exc.reason) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 7824f8bae..c61186e64 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -133,9 +133,14 @@ class CustomHTTPException(HTTPException): def test_websocket_repr(): - assert repr(WebSocketException(1008)) == ("WebSocketException(code=1008)") + assert repr(WebSocketException(1008, reason="Policy Violation")) == ( + "WebSocketException(code=1008, reason='Policy Violation')" + ) class CustomWebSocketException(WebSocketException): pass - assert repr(CustomWebSocketException(1013)) == "CustomWebSocketException(code=1013)" + assert ( + repr(CustomWebSocketException(1013, reason="Something custom")) + == "CustomWebSocketException(code=1013, reason='Something custom')" + ) From 193d9c104306aa0a3f0906d78ad257affe1cc03d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 14:32:12 +0200 Subject: [PATCH 08/14] Use Starlette's documentation style --- docs/exceptions.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/exceptions.md b/docs/exceptions.md index 13e22666f..c84a44d31 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -65,9 +65,12 @@ async def http_exception(request: Request, exc: HTTPException): You might also want to override how `WebSocketException` is handled: ```python -@app.exception_handler(WebSocketException) -async def websocket_exception(websocket, exc): +async def websocket_exception(websocket: WebSocket, exc: WebSocketException): await websocket.close(code=1008) + +exception_handlers = { + WebSocketException: websocket_exception +} ``` ## Errors and handled exceptions From 54a41961cec5b1877406c2702da6bd07bc83cc9b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 14:34:43 +0200 Subject: [PATCH 09/14] Add `reason` to documentation --- docs/exceptions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/exceptions.md b/docs/exceptions.md index c84a44d31..f97f1af89 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -128,6 +128,6 @@ classes should instead just return appropriate responses directly. You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints. -* `WebSocketException(code=1008)` +* `WebSocketException(code=1008, reason=None)` You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). From d572464dff006d77951f8a432699fb7d98a06208 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 14:50:26 +0200 Subject: [PATCH 10/14] Remove server logic --- starlette/middleware/errors.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index dd209fa3a..3f16a5165 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -180,13 +180,6 @@ async def _send(message: Message) -> None: if not response_started: await response(scope, receive, send) - elif scope["type"] == "websocket": - websocket = WebSocket(scope, receive, send) - # https://tools.ietf.org/html/rfc6455#section-7.4.1 - # 1011 indicates that a server is terminating the connection because - # it encountered an unexpected condition that prevented it from - # fulfilling the request. - await websocket.close(code=status.WS_1011_INTERNAL_ERROR) # We always continue to raise the exception. # This allows servers to log the error, or allows test clients From ca6b2e03df4bce19b6c0f9fe778f2d404898f714 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 15:47:33 +0200 Subject: [PATCH 11/14] Remove unused imports --- starlette/middleware/errors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 3f16a5165..90d0a5959 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -3,13 +3,11 @@ import traceback import typing -from starlette import status from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -from starlette.websockets import WebSocket STYLES = """ p { From 8ca322f471b6eacf34d3f50b475d76a760ec1e0c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 15:50:41 +0200 Subject: [PATCH 12/14] Fix failing test --- tests/middleware/test_errors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index bcd389a8f..414038246 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -5,7 +5,6 @@ from starlette.middleware.errors import ServerErrorMiddleware from starlette.responses import JSONResponse, Response from starlette.routing import Route -from starlette.websockets import WebSocketDisconnect def test_handler(test_client_factory): @@ -64,7 +63,7 @@ async def app(scope, receive, send): app = ServerErrorMiddleware(app) - with pytest.raises(WebSocketDisconnect): + with pytest.raises(RuntimeError): client = test_client_factory(app) with client.websocket_connect("/"): pass # pragma: nocover From d028f5dca6922572fe321c57e8d879862680c4fb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 15:54:04 +0200 Subject: [PATCH 13/14] ServerErrorMiddleware is only http again --- starlette/middleware/errors.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 90d0a5959..052b885f4 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -145,7 +145,7 @@ def __init__( self.debug = debug async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] not in {"http", "websocket"}: + if scope["type"] != "http": await self.app(scope, receive, send) return @@ -161,23 +161,22 @@ async def _send(message: Message) -> None: try: await self.app(scope, receive, _send) except Exception as exc: - if scope["type"] == "http": - request = Request(scope) - if self.debug: - # In debug mode, return traceback responses. - response = self.debug_response(request, exc) - elif self.handler is None: - # Use our default 500 error handler. - response = self.error_response(request, exc) + request = Request(scope) + if self.debug: + # In debug mode, return traceback responses. + response = self.debug_response(request, exc) + elif self.handler is None: + # Use our default 500 error handler. + response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if is_async_callable(self.handler): + response = await self.handler(request, exc) else: - # Use an installed 500 error handler. - if is_async_callable(self.handler): - response = await self.handler(request, exc) - else: - response = await run_in_threadpool(self.handler, request, exc) - - if not response_started: - await response(scope, receive, send) + response = await run_in_threadpool(self.handler, request, exc) + + if not response_started: + await response(scope, receive, send) # We always continue to raise the exception. # This allows servers to log the error, or allows test clients From 97a3eef3f7c39e9742e79356b1a3fcbcd7ad1cfd Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 13 Aug 2022 16:13:22 +0200 Subject: [PATCH 14/14] Rename test --- tests/middleware/test_errors.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 414038246..392c2ba16 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -57,7 +57,11 @@ async def app(scope, receive, send): client.get("/") -def test_debug_websocket(test_client_factory): +def test_debug_not_http(test_client_factory): + """ + DebugMiddleware should just pass through any non-http messages as-is. + """ + async def app(scope, receive, send): raise RuntimeError("Something went wrong")