From c751499583f92199a1793a24e81ac5c8998d1b1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sondre=20Lilleb=C3=B8=20Gundersen?= Date: Thu, 13 Jan 2022 11:29:17 +0100 Subject: [PATCH 01/10] Correct spelling of asynchronous in events.md (#1408) Remove count of available convertors (#1409) Co-authored-by: Micheal Gendy Fix md5_hexdigest wrapper on FIPS enabled systems (#1410) * Fix md5_hexdigest wrapper on FIPS enabled systems * Update _compat.py * lint Use typing `NoReturn` (#1412) change github issues template Sort third-party packages and add `starlette-wtf` (#1415) Improvements on authentication documentation (#1420) * Use `conn` in `AuthenticationBackend` documentation * Remove unused import in `AuthenticationBackend` documentation * Add missing imports in authentication documentation Co-authored-by: Marcelo Trylesinski Add third-party CSRF middlewares (#1414) * change github issues template * Add third-party CSRF middlewares Co-authored-by: Tom Christie Allow Environment options in `Jinja2Templates` (#1401) Adjust type of `exception_handlers` to allow async callable (#1423) Default WebSocket accept message headers to an empty list (#1422) * If no extra headers are passed, set it to an empty list * Test websocket.accept() with no additional headers * Update starlette/websockets.py Co-authored-by: Marcelo Trylesinski * Update tests/test_websockets.py Co-authored-by: Amin Alaee * Update tests/test_websockets.py Co-authored-by: Marcelo Trylesinski Co-authored-by: Marcelo Trylesinski Co-authored-by: Amin Alaee Add reason to WebSocket closure (#1417) Co-authored-by: Marcelo Trylesinski Version 0.18.0 (#1380) * Version 0.18.0 * Add changes until 14 jan * Update release-notes.md * Update release-notes.md * Update release-notes.md * Update release-notes.md Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Co-authored-by: Amin Alaee draft Update publish.yml (#1430) --- .github/workflows/publish.yml | 4 +- docs/authentication.md | 18 ++++-- docs/events.md | 2 +- docs/middleware.md | 8 +++ docs/release-notes.md | 22 +++++++ docs/routing.md | 2 +- docs/templates.md | 14 ++++ docs/third-party-packages.md | 116 +++++++++++++++++----------------- docs/websockets.md | 2 +- requirements.txt | 2 +- starlette/__init__.py | 2 +- starlette/_compat.py | 5 +- starlette/applications.py | 5 +- starlette/datastructures.py | 4 +- starlette/middleware/gzip.py | 3 +- starlette/requests.py | 6 +- starlette/responses.py | 23 +++++-- starlette/templating.py | 13 ++-- starlette/testclient.py | 4 +- starlette/websockets.py | 18 ++++-- tests/test_responses.py | 6 +- tests/test_websockets.py | 31 +++++++++ 22 files changed, 214 insertions(+), 96 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index b290d6e1a..514ed4b97 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,4 +1,3 @@ ---- name: Publish on: @@ -11,6 +10,9 @@ jobs: name: "Publish release" runs-on: "ubuntu-latest" + environment: + name: deploy + steps: - uses: "actions/checkout@v2" - uses: "actions/setup-python@v2" diff --git a/docs/authentication.md b/docs/authentication.md index d4af5b216..48eba6ca2 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -7,8 +7,7 @@ interfaces will be available in your endpoints. ```python from starlette.applications import Starlette from starlette.authentication import ( - AuthenticationBackend, AuthenticationError, SimpleUser, UnauthenticatedUser, - AuthCredentials + AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser ) from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -19,11 +18,11 @@ import binascii class BasicAuthBackend(AuthenticationBackend): - async def authenticate(self, request): - if "Authorization" not in request.headers: + async def authenticate(self, conn): + if "Authorization" not in conn.headers: return - auth = request.headers["Authorization"] + auth = conn.headers["Authorization"] try: scheme, credentials = auth.split() if scheme.lower() != 'basic': @@ -136,6 +135,10 @@ For class-based endpoints, you should wrap the decorator around a method on the class. ```python +from starlette.authentication import requires +from starlette.endpoints import HTTPEndpoint + + class Dashboard(HTTPEndpoint): @requires("authenticated") async def get(self, request): @@ -148,6 +151,11 @@ You can customise the error response sent when a `AuthenticationError` is raised by an auth backend: ```python +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + + def on_auth_error(request: Request, exc: Exception): return JSONResponse({"error": str(exc)}, status_code=401) diff --git a/docs/events.md b/docs/events.md index c7ed49e9d..f1514f96e 100644 --- a/docs/events.md +++ b/docs/events.md @@ -60,7 +60,7 @@ app = Starlette(routes=routes, lifespan=lifespan) ``` Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html) -for managing asynchronious tasks. +for managing asynchronous tasks. ## Running event handlers in tests diff --git a/docs/middleware.md b/docs/middleware.md index 5fe7ce516..f053e97fa 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -250,6 +250,10 @@ This middleware adds authentication to any ASGI application, requiring users to using their GitHub account (via [OAuth](https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps/)). Access can be restricted to specific users or to members of specific GitHub organizations or teams. +#### [asgi-csrf](https://github.com/simonw/asgi-csrf) + +Middleware for protecting against CSRF attacks. This middleware implements the Double Submit Cookie pattern, where a cookie is set, then it is compared to a csrftoken hidden form field or an `x-csrftoken` HTTP header. + #### [AuthlibMiddleware](https://github.com/aogier/starlette-authlib) A drop-in replacement for Starlette session middleware, using [authlib's jwt](https://docs.authlib.org/en/latest/jose/jwt.html) @@ -259,6 +263,10 @@ module. A middleware class for logging exceptions to [Bugsnag](https://www.bugsnag.com/). +#### [CSRFMiddleware](https://github.com/frankie567/starlette-csrf) + +Middleware for protecting against CSRF attacks. This middleware implements the Double Submit Cookie pattern, where a cookie is set, then it is compared to an `x-csrftoken` HTTP header. + #### [EarlyDataMiddleware](https://github.com/HarrySky/starlette-early-data) Middleware and decorator for detecting and denying [TLSv1.3 early data](https://tools.ietf.org/html/rfc8470) requests. diff --git a/docs/release-notes.md b/docs/release-notes.md index bd36079f9..672b22409 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,3 +1,25 @@ +## 0.18.0 + +January 23, 2022 + +#### Added +* Change default chunk size from 4Kb to 64Kb on `FileResponse` [#1345](https://github.com/encode/starlette/pull/1345). +* Add support for `functools.partial` in `WebSocketRoute` [#1356](https://github.com/encode/starlette/pull/1356). +* Add `StaticFiles` packages with directory [#1350](https://github.com/encode/starlette/pull/1350). +* Allow environment options in `Jinja2Templates` [#1401](https://github.com/encode/starlette/pull/1401). +* Allow HEAD method on `HttpEndpoint` [#1346](https://github.com/encode/starlette/pull/1346). +* Accept additional headers on `websocket.accept` message [#1361](https://github.com/encode/starlette/pull/1361) and [#1422](https://github.com/encode/starlette/pull/1422). +* Add `reason` to `WebSocket` close ASGI event [#1417](https://github.com/encode/starlette/pull/1417). +* Add headers attribute to `UploadFile` [#1382](https://github.com/encode/starlette/pull/1382). +* Don't omit `Content-Length` header for `Content-Length: 0` cases [#1395](https://github.com/encode/starlette/pull/1395). +* Don't set headers for responses with 1xx, 204 and 304 status code [#1397](https://github.com/encode/starlette/pull/1397). +* `SessionMiddleware.max_age` now accepts `None`, so cookie can last as long as the browser session [#1387](https://github.com/encode/starlette/pull/1387). + +#### Fixed +* Tweak `hashlib.md5()` function on `FileResponse`s ETag generation. The parameter [`usedforsecurity`](https://bugs.python.org/issue9216) flag is set to `False`, if the flag is available on the system. This fixes an error raised on systems with [FIPS](https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/FIPS_Mode_-_an_explanation) enabled [#1366](https://github.com/encode/starlette/pull/1366) and [#1410](https://github.com/encode/starlette/pull/1410). +* Fix `path_params` type on `url_path_for()` method i.e. turn `str` into `Any` [#1341](https://github.com/encode/starlette/pull/1341). +* `Host` now ignores `port` on routing [#1322](https://github.com/encode/starlette/pull/1322). + ## 0.17.1 November 17, 2021 diff --git a/docs/routing.md b/docs/routing.md index f0b672d8a..8bdc6215b 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -39,7 +39,7 @@ Route('/users/{username}', user) ``` By default this will capture characters up to the end of the path or the next `/`. -You can use convertors to modify what is captured. Four convertors are available: +You can use convertors to modify what is captured. The available convertors are: * `str` returns a string, and is the default. * `int` returns a Python integer. diff --git a/docs/templates.md b/docs/templates.md index b67669920..181cd1fef 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -64,6 +64,20 @@ def test_homepage(): assert "request" in response.context ``` +## Customizing Jinja2 Environment + +`Jinja2Templates` accepts all options supported by Jinja2 `Environment`. +This will allow more control over the `Enivornment` instance created by Starlette. + +For the list of options available to `Environment` you can check Jinja2 documentation [here](https://jinja.palletsprojects.com/en/3.0.x/api/#jinja2.Environment) + +```python +from starlette.templating import Jinja2Templates + + +templates = Jinja2Templates(directory='templates', autoescape=False, auto_reload=True) +``` + ## Asynchronous template rendering Jinja2 supports async template rendering, however as a general rule diff --git a/docs/third-party-packages.md b/docs/third-party-packages.md index 9052a066d..b3d788d17 100644 --- a/docs/third-party-packages.md +++ b/docs/third-party-packages.md @@ -3,7 +3,6 @@ Starlette has a rapidly growing community of developers, building tools that int Here are some of those third party packages: - ## Backports ### Python 3.5 port @@ -12,19 +11,25 @@ Here are some of those third party packages: ## Plugins -### Starlette APISpec +### Authlib -GitHub +GitHub | +Documentation -Simple APISpec integration for Starlette. -Document your REST API built with Starlette by declaring OpenAPI (Swagger) -schemas in YAML format in your endpoint's docstrings. +The ultimate Python library in building OAuth and OpenID Connect clients and servers. Check out how to integrate with [Starlette](https://docs.authlib.org/en/latest/client/starlette.html). -### SpecTree +### ChannelBox -GitHub +GitHub -Generate OpenAPI spec document and validate request & response with Python annotations. Less boilerplate code(no need for YAML). +Another solution for websocket broadcast. Send messages to channel groups from any part of your code. +Checkout MySimpleChat, a simple chat application built using `channel-box` and `starlette`. + +### Imia + +GitHub + +An authentication framework for Starlette with pluggable authenticators and login/logout flow. ### Mangum @@ -39,13 +44,6 @@ Serverless ASGI adapter for AWS Lambda & API Gateway. Manage and send messages to groups of channels using websockets. Checkout nejma-chat, a simple chat application built using `nejma` and `starlette`. -### ChannelBox - -GitHub - -Another solution for websocket broadcast. Send messages to channel groups from any part of your code. -Checkout MySimpleChat, a simple chat application built using `channel-box` and `starlette`. - ### Scout APM GitHub @@ -53,28 +51,32 @@ Checkout MySimp An APM (Application Performance Monitoring) solution that can instrument your application to find performance bottlenecks. -### Starlette Prometheus +### SpecTree -GitHub +GitHub -A plugin for providing an endpoint that exposes [Prometheus](https://prometheus.io/) metrics based on its [official python client](https://github.com/prometheus/client_python). +Generate OpenAPI spec document and validate request & response with Python annotations. Less boilerplate code(no need for YAML). -### webargs-starlette +### Starlette APISpec -GitHub +GitHub -Declarative request parsing and validation for Starlette, built on top -of [webargs](https://github.com/marshmallow-code/webargs). +Simple APISpec integration for Starlette. +Document your REST API built with Starlette by declaring OpenAPI (Swagger) +schemas in YAML format in your endpoint's docstrings. -Allows you to parse querystring, JSON, form, headers, and cookies using -type annotations. +### Starlette Context -### Authlib +GitHub -GitHub | -Documentation +Middleware for Starlette that allows you to store and access the context data of a request. +Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id. -The ultimate Python library in building OAuth and OpenID Connect clients and servers. Check out how to integrate with [Starlette](https://docs.authlib.org/en/latest/client/starlette.html). +### Starlette Cramjam + +GitHub + +A Starlette middleware that allows **brotli**, **gzip** and **deflate** compression algorithm with a minimal requirements. ### Starlette OAuth2 API @@ -83,13 +85,17 @@ The ultimate Python library in building OAuth and OpenID Connect clients and ser A starlette middleware to add authentication and authorization through JWTs. It relies solely on an auth provider to issue access and/or id tokens to clients. -### Starlette Context +### Starlette Prometheus -GitHub +GitHub -Middleware for Starlette that allows you to store and access the context data of a request. -Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id. +A plugin for providing an endpoint that exposes [Prometheus](https://prometheus.io/) metrics based on its [official python client](https://github.com/prometheus/client_python). +### Starlette WTF + +GitHub + +A simple tool for integrating Starlette and WTForms. It is modeled on the excellent Flask-WTF library. ### Starsessions @@ -97,31 +103,18 @@ Can be used with logging so logs automatically use request headers such as x-req An alternate session support implementation with customizable storage backends. +### webargs-starlette -### Starlette Cramjam - -GitHub - -A Starlette middleware that allows **brotli**, **gzip** and **deflate** compression algorithm with a minimal requirements. - - -### Imia - -GitHub +GitHub -An authentication framework for Starlette with pluggable authenticators and login/logout flow. +Declarative request parsing and validation for Starlette, built on top +of [webargs](https://github.com/marshmallow-code/webargs). +Allows you to parse querystring, JSON, form, headers, and cookies using +type annotations. ## Frameworks -### Responder - -GitHub | -Documentation - -Async web service framework. Some Features: flask-style route expression, -yaml support, OpenAPI schema generation, background tasks, graphql. - ### FastAPI GitHub | @@ -139,12 +132,6 @@ Formerly Starlette API. Flama aims to bring a layer on top of Starlette to provide an **easy to learn** and **fast to develop** approach for building **highly performant** GraphQL and REST APIs. In the same way of Starlette is, Flama is a perfect option for developing **asynchronous** and **production-ready** services. -### Starlette-apps - -Roll your own framework with a simple app system, like [Django-GDAPS](https://gdaps.readthedocs.io/en/latest/) or [CakePHP](https://cakephp.org/). - -GitHub - ### Greppo GitHub | @@ -154,3 +141,16 @@ A Python framework for building geospatial dashboards and web-applications. Greppo is an open-source Python framework that makes it easy to build geospatial dashboards and web-applications. It provides a toolkit to quickly integrate data, algorithms, visualizations and UI for interactivity. It provides APIs to the update the variables in the backend, recompute the logic, and reflect the changes in the frontend (data mutation hook). +### Responder + +GitHub | +Documentation + +Async web service framework. Some Features: flask-style route expression, +yaml support, OpenAPI schema generation, background tasks, graphql. + +### Starlette-apps + +Roll your own framework with a simple app system, like [Django-GDAPS](https://gdaps.readthedocs.io/en/latest/) or [CakePHP](https://cakephp.org/). + +GitHub diff --git a/docs/websockets.md b/docs/websockets.md index 43406aced..1128bce43 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da ### Closing the connection -* `await websocket.close(code=1000)` +* `await websocket.close(code=1000, reason=None)` ### Sending and receiving messages diff --git a/requirements.txt b/requirements.txt index 48f3cfa3e..51389304e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ coverage==6.2 databases[sqlite]==0.5.3 flake8==4.0.1 isort==5.10.1 -mypy==0.930 +mypy==0.931 types-requests==2.26.3 types-contextvars==2.4.0 types-PyYAML==6.0.1 diff --git a/starlette/__init__.py b/starlette/__init__.py index c6eae9f8a..1317d7554 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.17.1" +__version__ = "0.18.0" diff --git a/starlette/_compat.py b/starlette/_compat.py index 82aa72f38..116561917 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -11,7 +11,10 @@ # See issue: https://github.com/encode/starlette/issues/1365 try: - hashlib.md5(b"data", usedforsecurity=True) # type: ignore[call-arg] + # check if the Python version supports the parameter + # using usedforsecurity=False to avoid an exception on FIPS systems + # that reject usedforsecurity=True + hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg] def md5_hexdigest( data: bytes, *, usedforsecurity: bool = True diff --git a/starlette/applications.py b/starlette/applications.py index 9f05dc286..ab6792527 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -44,7 +44,10 @@ def __init__( routes: typing.Sequence[BaseRoute] = None, middleware: typing.Sequence[Middleware] = None, exception_handlers: typing.Mapping[ - typing.Any, typing.Callable[[Request, Exception], Response] + typing.Any, + typing.Callable[ + [Request, Exception], typing.Union[Response, typing.Awaitable[Response]] + ], ] = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 52904d51e..2c5c4b016 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -163,7 +163,7 @@ class URLPath(str): def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath": assert protocol in ("http", "websocket", "") - return str.__new__(cls, path) # type: ignore + return str.__new__(cls, path) def __init__(self, path: str, protocol: str = "", host: str = "") -> None: self.protocol = protocol @@ -441,7 +441,7 @@ def _in_memory(self) -> bool: async def write(self, data: bytes) -> None: if self._in_memory: - self.file.write(data) # type: ignore + self.file.write(data) else: await run_in_threadpool(self.file.write, data) diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index 37c6936fa..9d69ee7ca 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -1,5 +1,6 @@ import gzip import io +import typing from starlette.datastructures import Headers, MutableHeaders from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -100,5 +101,5 @@ async def send_with_gzip(self, message: Message) -> None: await self.send(message) -async def unattached_send(message: Message) -> None: +async def unattached_send(message: Message) -> typing.NoReturn: raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/starlette/requests.py b/starlette/requests.py index cf129702e..a33367e1d 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -51,7 +51,7 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: key, val = key.strip(), val.strip() if key or val: # unquote using Python's algorithm. - cookie_dict[key] = http_cookies._unquote(val) # type: ignore + cookie_dict[key] = http_cookies._unquote(val) return cookie_dict @@ -175,11 +175,11 @@ def url_for(self, name: str, **path_params: typing.Any) -> str: return url_path.make_absolute_url(base_url=self.base_url) -async def empty_receive() -> Message: +async def empty_receive() -> typing.NoReturn: raise RuntimeError("Receive channel has not been made available") -async def empty_send(message: Message) -> None: +async def empty_send(message: Message) -> typing.NoReturn: raise RuntimeError("Send channel has not been made available") diff --git a/starlette/responses.py b/starlette/responses.py index 26d730540..326e743c5 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -33,25 +33,31 @@ def guess_type( class Response: media_type = None charset = "utf-8" + default_status_code = 200 + default_response = b"" def __init__( self, content: typing.Any = None, - status_code: int = 200, + status_code: int = None, headers: dict = None, media_type: str = None, background: BackgroundTask = None, ) -> None: - self.status_code = status_code if media_type is not None: self.media_type = media_type self.background = background - self.body = self.render(content) + + if content is None: + self.body = self.default_response + self.status_code = status_code + else: + self.body = self.render(content) + self.status_code = status_code or self.default_status_code + self.init_headers(headers) def render(self, content: typing.Any) -> bytes: - if content is None: - return b"" if isinstance(content, bytes): return content return content.encode(self.charset) @@ -74,7 +80,11 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: if ( body is not None and populate_content_length - and not (self.status_code < 200 or self.status_code in (204, 304)) + and not ( + self.status_code + and self.status_code < 200 + or self.status_code in (204, 304) + ) ): content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1"))) @@ -173,6 +183,7 @@ class PlainTextResponse(Response): class JSONResponse(Response): media_type = "application/json" + default_response = b"{}" def render(self, content: typing.Any) -> bytes: return json.dumps( diff --git a/starlette/templating.py b/starlette/templating.py index 18d5eb40c..a44edddc2 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -55,12 +55,14 @@ class Jinja2Templates: return templates.TemplateResponse("index.html", {"request": request}) """ - def __init__(self, directory: typing.Union[str, PathLike]) -> None: + def __init__( + self, directory: typing.Union[str, PathLike], **env_options: typing.Any + ) -> None: assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" - self.env = self._create_env(directory) + self.env = self._create_env(directory, **env_options) def _create_env( - self, directory: typing.Union[str, PathLike] + self, directory: typing.Union[str, PathLike], **env_options: typing.Any ) -> "jinja2.Environment": @pass_context def url_for(context: dict, name: str, **path_params: typing.Any) -> str: @@ -68,7 +70,10 @@ def url_for(context: dict, name: str, **path_params: typing.Any) -> str: return request.url_for(name, **path_params) loader = jinja2.FileSystemLoader(directory) - env = jinja2.Environment(loader=loader, autoescape=True) + env_options.setdefault("loader", loader) + env_options.setdefault("autoescape", True) + + env = jinja2.Environment(**env_options) env.globals["url_for"] = url_for return env diff --git a/starlette/testclient.py b/starlette/testclient.py index 0b4bc78d1..c951767b4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -352,7 +352,9 @@ async def _asgi_send(self, message: Message) -> None: def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": - raise WebSocketDisconnect(message.get("code", 1000)) + raise WebSocketDisconnect( + message.get("code", 1000), message.get("reason", "") + ) def send(self, message: Message) -> None: self._receive_queue.put(message) diff --git a/starlette/websockets.py b/starlette/websockets.py index 7632b28cf..da7406047 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum): class WebSocketDisconnect(Exception): - def __init__(self, code: int = 1000) -> None: + def __init__(self, code: int = 1000, reason: str = None) -> None: self.code = code + self.reason = reason or "" class WebSocket(HTTPConnection): @@ -74,6 +75,8 @@ async def accept( subprotocol: str = None, headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None, ) -> None: + headers = headers or [] + if self.client_state == WebSocketState.CONNECTING: # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() @@ -144,13 +147,18 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None: else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) - async def close(self, code: int = 1000) -> None: - await self.send({"type": "websocket.close", "code": code}) + async def close(self, code: int = 1000, reason: str = None) -> None: + await self.send( + {"type": "websocket.close", "code": code, "reason": reason or ""} + ) class WebSocketClose: - def __init__(self, code: int = 1000) -> None: + def __init__(self, code: int = 1000, reason: str = None) -> None: self.code = code + self.reason = reason or "" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await send({"type": "websocket.close", "code": self.code}) + await send( + {"type": "websocket.close", "code": self.code, "reason": self.reason} + ) diff --git a/tests/test_responses.py b/tests/test_responses.py index e2337bdca..7f1095f38 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -38,12 +38,12 @@ async def app(scope, receive, send): def test_json_none_response(test_client_factory): async def app(scope, receive, send): - response = JSONResponse(None) + response = JSONResponse(status_code=200) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") - assert response.json() is None + assert response.json() == {} def test_redirect_response(test_client_factory): @@ -327,7 +327,7 @@ def test_head_method(test_client_factory): def test_empty_response(test_client_factory): - app = Response() + app = Response(status_code=200) client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "0" diff --git a/tests/test_websockets.py b/tests/test_websockets.py index bf0253309..b11685cbc 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -315,6 +315,20 @@ async def asgi(receive, send): assert websocket.extra_headers == [(b"additional", b"header")] +def test_no_additional_headers(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.close() + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + assert websocket.extra_headers == [] + + def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send): @@ -391,3 +405,20 @@ async def mock_send(message): assert websocket == websocket assert websocket in {websocket} assert {websocket} == {websocket} + + +def test_websocket_close_reason(test_client_factory) -> None: + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away") + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + with pytest.raises(WebSocketDisconnect) as exc: + websocket.receive_text() + assert exc.value.code == status.WS_1001_GOING_AWAY + assert exc.value.reason == "Going Away" From 730f100bca964c964f06b0cd58a8737121b63eb0 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 12:46:06 +0100 Subject: [PATCH 02/10] Update starlette/responses.py Co-authored-by: Tom Christie --- starlette/responses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 326e743c5..39af0e1d9 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -49,8 +49,8 @@ def __init__( self.background = background if content is None: - self.body = self.default_response - self.status_code = status_code + self.body = b"" + self.status_code = status_code or 204 else: self.body = self.render(content) self.status_code = status_code or self.default_status_code From 065514e733db7423ab9d63467d6245150682d1f3 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 12:46:15 +0100 Subject: [PATCH 03/10] Update starlette/responses.py Co-authored-by: Tom Christie --- starlette/responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/responses.py b/starlette/responses.py index 39af0e1d9..4dfb65a26 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -53,7 +53,7 @@ def __init__( self.status_code = status_code or 204 else: self.body = self.render(content) - self.status_code = status_code or self.default_status_code + self.status_code = status_code or 200 self.init_headers(headers) From 9fef183fb304c8be40df4bb5276331da7c234230 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 12:46:23 +0100 Subject: [PATCH 04/10] Update tests/test_responses.py Co-authored-by: Tom Christie --- tests/test_responses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_responses.py b/tests/test_responses.py index 7f1095f38..d1a0b5229 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -327,10 +327,11 @@ def test_head_method(test_client_factory): def test_empty_response(test_client_factory): - app = Response(status_code=200) + app = Response() client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "0" + assert response.status_code == 204 def test_empty_204_response(test_client_factory): From e02a7fa9312b0b8d9bd15b5e385009540dafc8d3 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 12:46:38 +0100 Subject: [PATCH 05/10] Update tests/test_responses.py Co-authored-by: Tom Christie --- tests/test_responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_responses.py b/tests/test_responses.py index d1a0b5229..4a1723724 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -38,7 +38,7 @@ async def app(scope, receive, send): def test_json_none_response(test_client_factory): async def app(scope, receive, send): - response = JSONResponse(status_code=200) + response = JSONResponse() await response(scope, receive, send) client = test_client_factory(app) From ad565b6056cb6b2cb42ec1785b54978e26541384 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 12:47:00 +0100 Subject: [PATCH 06/10] Update tests/test_responses.py Co-authored-by: Tom Christie --- tests/test_responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_responses.py b/tests/test_responses.py index 4a1723724..99413714a 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -43,7 +43,7 @@ async def app(scope, receive, send): client = test_client_factory(app) response = client.get("/") - assert response.json() == {} + assert response.content == b"" def test_redirect_response(test_client_factory): From f193ae758eb358559ffc6d7479b0b0d916c30e8a Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 12:50:46 +0100 Subject: [PATCH 07/10] update --- starlette/responses.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 4dfb65a26..bc7464306 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -33,8 +33,6 @@ def guess_type( class Response: media_type = None charset = "utf-8" - default_status_code = 200 - default_response = b"" def __init__( self, @@ -80,11 +78,7 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: if ( body is not None and populate_content_length - and not ( - self.status_code - and self.status_code < 200 - or self.status_code in (204, 304) - ) + and not (self.status_code < 200 or self.status_code in (204, 304)) ): content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1"))) @@ -183,7 +177,6 @@ class PlainTextResponse(Response): class JSONResponse(Response): media_type = "application/json" - default_response = b"{}" def render(self, content: typing.Any) -> bytes: return json.dumps( From 425427f961df58875ae4428d44312a19d6b24ef7 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 25 Jan 2022 13:14:26 +0100 Subject: [PATCH 08/10] update test --- tests/test_responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_responses.py b/tests/test_responses.py index 99413714a..21f563077 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -330,8 +330,8 @@ def test_empty_response(test_client_factory): app = Response() client: TestClient = test_client_factory(app) response = client.get("/") - assert response.headers["content-length"] == "0" assert response.status_code == 204 + assert "content-length" not in response.headers def test_empty_204_response(test_client_factory): From af3bc9a481d29a61aacda66241f338a6e0052cca Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Wed, 26 Jan 2022 12:34:32 +0100 Subject: [PATCH 09/10] revert status_code change --- starlette/responses.py | 23 +++++++++++++++-------- tests/test_responses.py | 10 ++++++---- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index bc7464306..9723cdaa2 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -37,25 +37,22 @@ class Response: def __init__( self, content: typing.Any = None, - status_code: int = None, + status_code: int = 200, headers: dict = None, media_type: str = None, background: BackgroundTask = None, ) -> None: + self.status_code = status_code if media_type is not None: self.media_type = media_type self.background = background - if content is None: - self.body = b"" - self.status_code = status_code or 204 - else: - self.body = self.render(content) - self.status_code = status_code or 200 - + self.body = self.render(content) self.init_headers(headers) def render(self, content: typing.Any) -> bytes: + if content is None: + return b"" if isinstance(content, bytes): return content return content.encode(self.charset) @@ -178,6 +175,16 @@ class PlainTextResponse(Response): class JSONResponse(Response): media_type = "application/json" + def __init__( + self, + content: typing.Any, + status_code: int = 200, + headers: dict = None, + media_type: str = None, + background: BackgroundTask = None, + ) -> None: + super().__init__(content, status_code, headers, media_type, background) + def render(self, content: typing.Any) -> bytes: return json.dumps( content, diff --git a/tests/test_responses.py b/tests/test_responses.py index 21f563077..38aea5761 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -38,12 +38,13 @@ async def app(scope, receive, send): def test_json_none_response(test_client_factory): async def app(scope, receive, send): - response = JSONResponse() + response = JSONResponse(None) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") - assert response.content == b"" + assert response.json() is None + assert response.content == b"null" def test_redirect_response(test_client_factory): @@ -330,8 +331,9 @@ def test_empty_response(test_client_factory): app = Response() client: TestClient = test_client_factory(app) response = client.get("/") - assert response.status_code == 204 - assert "content-length" not in response.headers + assert response.content == b"" + assert response.headers["content-length"] == "0" + assert "content-type" not in response.headers def test_empty_204_response(test_client_factory): From 6985784d656bc130b80bd4d538e8c702f6c76b94 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Wed, 26 Jan 2022 13:18:27 +0100 Subject: [PATCH 10/10] remove whitespace --- starlette/responses.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/responses.py b/starlette/responses.py index 9723cdaa2..577d1bb04 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -46,7 +46,6 @@ def __init__( if media_type is not None: self.media_type = media_type self.background = background - self.body = self.render(content) self.init_headers(headers)