diff --git a/docs/authentication.md b/docs/authentication.md index 48eba6ca2..d6cec3fb2 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -131,6 +131,29 @@ async def dashboard(request): ... ``` +When redirecting users, the page you redirect them to will include URL they originally requested at the `next` query param: + +```python +from starlette.authentication import requires +from starlette.responses import RedirectResponse + + +@requires('authenticated', redirect='login') +async def admin(request): + ... + + +async def login(request): + if request.method == "POST": + # Now that the user is authenticated, + # we can send them to their original request destination + if request.user.is_authenticated: + next_url = request.query_params.get("next") + if next_url: + return RedirectResponse(next_url) + return RedirectResponse("/") +``` + For class-based endpoints, you should wrap the decorator around a method on the class. diff --git a/starlette/authentication.py b/starlette/authentication.py index b4882070d..1a4cba377 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -2,6 +2,7 @@ import functools import inspect import typing +from urllib.parse import urlencode from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request @@ -63,9 +64,12 @@ async def async_wrapper( if not has_required_scope(request, scopes_list): if redirect is not None: - return RedirectResponse( - url=request.url_for(redirect), status_code=303 + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = "{redirect_path}?{orig_request}".format( + redirect_path=request.url_for(redirect), + orig_request=orig_request_qparam, ) + return RedirectResponse(url=next_url, status_code=303) raise HTTPException(status_code=status_code) return await func(*args, **kwargs) @@ -80,9 +84,12 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: if not has_required_scope(request, scopes_list): if redirect is not None: - return RedirectResponse( - url=request.url_for(redirect), status_code=303 + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = "{redirect_path}?{orig_request}".format( + redirect_path=request.url_for(redirect), + orig_request=orig_request_qparam, ) + return RedirectResponse(url=next_url, status_code=303) raise HTTPException(status_code=status_code) return func(*args, **kwargs) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 65b49c3ca..af0beafd0 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,6 @@ import base64 import binascii +from urllib.parse import urlencode import pytest @@ -305,7 +306,10 @@ def test_authentication_redirect(test_client_factory): with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 - assert response.url == "http://testserver/" + url = "{}?{}".format( + "http://testserver/", urlencode({"next": "http://testserver/admin"}) + ) + assert response.url == url response = client.get("/admin", auth=("tomchristie", "example")) assert response.status_code == 200 @@ -313,7 +317,10 @@ def test_authentication_redirect(test_client_factory): response = client.get("/admin/sync") assert response.status_code == 200 - assert response.url == "http://testserver/" + url = "{}?{}".format( + "http://testserver/", urlencode({"next": "http://testserver/admin/sync"}) + ) + assert response.url == url response = client.get("/admin/sync", auth=("tomchristie", "example")) assert response.status_code == 200