Skip to content

Commit

Permalink
Merge branch 'master' into task-group
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Jun 7, 2023
2 parents 242e94f + da7adf2 commit b921c56
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 76 deletions.
1 change: 0 additions & 1 deletion .github/pull_request_template.md
Expand Up @@ -10,4 +10,3 @@ Given this is a project maintained by volunteers, please read this template to n
- [ ] I understand that this PR may be closed in case there was no previous discussion. (This doesn't apply to typos!)
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [ ] I've updated the documentation accordingly.
- [ ] I've added a note on the `docs/release-notes.md` about my changes (changes on documentation or type annotations don't need this).
15 changes: 15 additions & 0 deletions docs/release-notes.md
@@ -1,3 +1,18 @@
## 0.28.0

June 7, 2023

### Changed
* Reuse `Request`'s body buffer for call_next in `BaseHTTPMiddleware` [#1692](https://github.com/encode/starlette/pull/1692).
* Move exception handling logic to `Route` [#2026](https://github.com/encode/starlette/pull/2026).

### Added
* Add `env` parameter to `Jinja2Templates`, and deprecate `**env_options` [#2159](https://github.com/encode/starlette/pull/2159).
* Add clear error message when `httpx` is not installed [#2177](https://github.com/encode/starlette/pull/2177).

### Fixed
* Allow "name" argument on `templates url_for()` [#2127](https://github.com/encode/starlette/pull/2127).

## 0.27.0

May 16, 2023
Expand Down
2 changes: 1 addition & 1 deletion starlette/__init__.py
@@ -1 +1 @@
__version__ = "0.27.0"
__version__ = "0.28.0"
76 changes: 76 additions & 0 deletions starlette/_exception_handler.py
@@ -0,0 +1,76 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

Handler = typing.Callable[..., typing.Any]
ExceptionHandlers = typing.Dict[typing.Any, Handler]
StatusHandlers = typing.Dict[int, Handler]


def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> typing.Optional[Handler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None


def wrap_app_handling_exceptions(
app: ASGIApp, conn: typing.Union[Request, WebSocket]
) -> ASGIApp:
exception_handlers: ExceptionHandlers
status_handlers: StatusHandlers
try:
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)

if handler is None:
handler = _lookup_exception_handler(exception_handlers, exc)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if scope["type"] == "http":
response: Response
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
if is_async_callable(handler):
await handler(conn, exc)
else:
await run_in_threadpool(handler, conn, exc)

return wrapped_app
83 changes: 24 additions & 59 deletions starlette/middleware/exceptions.py
@@ -1,11 +1,14 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


Expand All @@ -20,12 +23,10 @@ def __init__(
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
WebSocketException: self.websocket_exception, # type: ignore[dict-item]
}
if handlers is not None:
for key, value in handlers.items():
Expand All @@ -42,68 +43,32 @@ def add_exception_handler(
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler

def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)

if handler is None:
handler = self._lookup_exception_handler(exc)

if handler is None:
raise exc
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
conn: typing.Union[Request, WebSocket]
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)

if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
23 changes: 16 additions & 7 deletions starlette/routing.py
Expand Up @@ -9,6 +9,7 @@
from contextlib import asynccontextmanager
from enum import Enum

from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
Expand Down Expand Up @@ -61,12 +62,16 @@ def request_response(func: typing.Callable) -> ASGIApp:
is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)
request = Request(scope, receive, send)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)

await wrap_app_handling_exceptions(app, request)(scope, receive, send)

return app

Expand All @@ -79,7 +84,11 @@ def websocket_session(func: typing.Callable) -> ASGIApp:

async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
await func(session)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)

await wrap_app_handling_exceptions(app, session)(scope, receive, send)

return app

Expand Down
10 changes: 9 additions & 1 deletion starlette/testclient.py
Expand Up @@ -13,13 +13,21 @@

import anyio
import anyio.from_thread
import httpx
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

try:
import httpx
except ModuleNotFoundError: # pragma: no cover
raise RuntimeError(
"The starlette.testclient module requires the httpx package to be installed.\n"
"You can install this with:\n"
" $ pip install httpx\n"
)

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
else: # pragma: no cover
Expand Down
35 changes: 33 additions & 2 deletions tests/test_exceptions.py
Expand Up @@ -4,7 +4,8 @@

from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute


Expand All @@ -28,6 +29,22 @@ def with_headers(request):
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class BadBodyException(HTTPException):
pass


async def read_body_and_raise_exc(request: Request):
await request.body()
raise BadBodyException(422)


async def handler_that_reads_body(
request: Request, exc: BadBodyException
) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
response = PlainTextResponse("OK", status_code=200)
Expand All @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send):
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route(
"/consume_body_in_endpoint_and_handler",
endpoint=read_body_and_raise_exc,
methods=["POST"],
),
]
)


app = ExceptionMiddleware(router)
app = ExceptionMiddleware(
router,
handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item]
)


@pytest.fixture
Expand Down Expand Up @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None:

with pytest.warns(DeprecationWarning):
starlette.exceptions.ExceptionMiddleware


def test_request_in_app_and_handler_is_the_same_object(client) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}
6 changes: 1 addition & 5 deletions tests/test_routing.py
Expand Up @@ -1033,13 +1033,9 @@ async def modified_send(msg: Message) -> None:
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers

# this is the "surprising" behavior bit
# the middleware on the mount never runs because there
# is nothing to catch the HTTPException
# since Mount middlweare is not wrapped by ExceptionMiddleware
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" not in resp.headers
assert "X-Mounted" in resp.headers


def test_route_repr() -> None:
Expand Down

0 comments on commit b921c56

Please sign in to comment.