From 2d6ddd386199e9e6cf0df0849e1de7d1a8c86b9d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 10 Jan 2022 12:08:31 +0100 Subject: [PATCH 1/8] Don't set headers for responses with 1xx, 204 and 304 status code (#1397) * Don't set headers for responses with 1xx, 204 and 304 status code * Fix test Co-authored-by: Tom Christie --- starlette/responses.py | 6 +++++- tests/test_responses.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/starlette/responses.py b/starlette/responses.py index da765cfa9..26d730540 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -71,7 +71,11 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: populate_content_type = b"content-type" not in keys body = getattr(self, "body", None) - if body is not None and populate_content_length: + if ( + body is not None + and populate_content_length + and not (self.status_code < 200 or self.status_code in (204, 304)) + ): content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1"))) diff --git a/tests/test_responses.py b/tests/test_responses.py index 150fe4795..e2337bdca 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -333,6 +333,13 @@ def test_empty_response(test_client_factory): assert response.headers["content-length"] == "0" +def test_empty_204_response(test_client_factory): + app = Response(status_code=204) + client: TestClient = test_client_factory(app) + response = client.get("/") + assert "content-length" not in response.headers + + def test_non_empty_response(test_client_factory): app = Response(content="hi") client: TestClient = test_client_factory(app) From 4e86245c4cf891bc72535e10806e122c77fad37b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 10 Jan 2022 09:41:43 -0800 Subject: [PATCH 2/8] Document and type annotate UploadFile as a bytes-only interface. Not bytes or text. (#1312) --- docs/requests.md | 4 ++-- starlette/datastructures.py | 12 +++++++----- tests/test_datastructures.py | 12 ++++++++++++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/docs/requests.md b/docs/requests.md index f4d867ab1..872946638 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -126,8 +126,8 @@ multidict, containing both file uploads and text input. File upload items are re `UploadFile` has the following `async` methods. They all call the corresponding file methods underneath (using the internal `SpooledTemporaryFile`). -* `async write(data)`: Writes `data` (`str` or `bytes`) to the file. -* `async read(size)`: Reads `size` (`int`) bytes/characters of the file. +* `async write(data)`: Writes `data` (`bytes`) to the file. +* `async read(size)`: Reads `size` (`int`) bytes of the file. * `async seek(offset)`: Goes to the byte position `offset` (`int`) in the file. * E.g., `await myfile.seek(0)` would go to the start of the file. * `async close()`: Closes the file. diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 64f964a91..52904d51e 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -415,12 +415,13 @@ class UploadFile: """ spool_max_size = 1024 * 1024 + file: typing.BinaryIO headers: "Headers" def __init__( self, filename: str, - file: typing.IO = None, + file: typing.Optional[typing.BinaryIO] = None, content_type: str = "", *, headers: "typing.Optional[Headers]" = None, @@ -428,8 +429,9 @@ def __init__( self.filename = filename self.content_type = content_type if file is None: - file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size) - self.file = file + self.file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size) # type: ignore # noqa: E501 + else: + self.file = file self.headers = headers or Headers() @property @@ -437,13 +439,13 @@ def _in_memory(self) -> bool: rolled_to_disk = getattr(self.file, "_rolled", True) return not rolled_to_disk - async def write(self, data: typing.Union[bytes, str]) -> None: + async def write(self, data: bytes) -> None: if self._in_memory: self.file.write(data) # type: ignore else: await run_in_threadpool(self.file.write, data) - async def read(self, size: int = -1) -> typing.Union[bytes, str]: + async def read(self, size: int = -1) -> bytes: if self._in_memory: return self.file.read(size) return await run_in_threadpool(self.file.read, size) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index bb71ba870..5d44bdc13 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -227,6 +227,18 @@ async def test_upload_file(): await big_file.close() +@pytest.mark.anyio +async def test_upload_file_file_input(): + """Test passing file/stream into the UploadFile constructor""" + stream = io.BytesIO(b"data") + file = UploadFile(filename="file", file=stream) + assert await file.read() == b"data" + await file.write(b" and more data!") + assert await file.read() == b"" + await file.seek(0) + assert await file.read() == b"data and more data!" + + def test_formdata(): upload = io.BytesIO(b"test") form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) From 3c93a19cef91ae856a5d8f02e09b481525071a07 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Tue, 11 Jan 2022 09:28:39 +0100 Subject: [PATCH 3/8] Add Mypy checks to tests (#1353) --- setup.cfg | 3 +-- tests/middleware/test_cors.py | 2 +- tests/middleware/test_session.py | 8 ++++++-- tests/test_database.py | 2 ++ tests/test_datastructures.py | 7 ++----- tests/test_formparsers.py | 3 ++- tests/test_routing.py | 8 ++++---- tests/test_templates.py | 2 +- 8 files changed, 19 insertions(+), 16 deletions(-) diff --git a/setup.cfg b/setup.cfg index 0595a7463..3089eaaf7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,8 +9,7 @@ show_error_codes = True [mypy-tests.*] disallow_untyped_defs = False -# https://github.com/encode/starlette/issues/1045 -# check_untyped_defs = True +check_untyped_defs = True [tool:isort] profile = black diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 65252e502..2f0ca3d34 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -258,7 +258,7 @@ def test_cors_allow_all_methods(test_client_factory): ) @app.route( - "/", methods=("delete", "get", "head", "options", "patch", "post", "put") + "/", methods=["delete", "get", "head", "options", "patch", "post", "put"] ) def homepage(request): return PlainTextResponse("Homepage", status_code=200) diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 42f4447e5..07296bcbb 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -66,7 +66,9 @@ def test_session_expires(test_client_factory): # requests removes expired cookies from response.cookies, we need to # fetch session id from the headers and pass it explicitly expired_cookie_header = response.headers["set-cookie"] - expired_session_value = re.search(r"session=([^;]*);", expired_cookie_header)[1] + expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header) + assert expired_session_match is not None + expired_session_value = expired_session_match[1] response = client.get("/view_session", cookies={"session": expired_session_value}) assert response.json() == {"session": {}} @@ -110,7 +112,9 @@ def test_session_cookie_subpath(test_client_factory): client = test_client_factory(app, base_url="http://testserver/second_app") response = client.post("second_app/update_session", json={"some": "data"}) cookie = response.headers["set-cookie"] - cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] + cookie_path_match = re.search(r"; path=(\S+);", cookie) + assert cookie_path_match is not None + cookie_path = cookie_path_match.groups()[0] assert cookie_path == "/second_app" diff --git a/tests/test_database.py b/tests/test_database.py index c0a4745d1..11f770bb1 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -77,6 +77,7 @@ async def read_note(request): note_id = request.path_params["note_id"] query = notes.select().where(notes.c.id == note_id) result = await database.fetch_one(query) + assert result is not None content = {"text": result["text"], "completed": result["completed"]} return JSONResponse(content) @@ -86,6 +87,7 @@ async def read_note_text(request): note_id = request.path_params["note_id"] query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id) result = await database.fetch_one(query) + assert result is not None return JSONResponse(result[0]) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 5d44bdc13..b110aa8bd 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -240,7 +240,8 @@ async def test_upload_file_file_input(): def test_formdata(): - upload = io.BytesIO(b"test") + stream = io.BytesIO(b"data") + upload = UploadFile(filename="file", file=stream) form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) assert "a" in form assert "A" not in form @@ -350,10 +351,6 @@ def test_multidict(): q.update(q) assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" - q = MultiDict([("a", "123"), ("b", "456")]) - q.update(None) - assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" - q = MultiDict([("a", "123"), ("a", "456")]) q.update([("a", "123")]) assert q.getlist("a") == ["123"] diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 384a885dc..05f0f053c 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,4 +1,5 @@ import os +import typing from starlette.formparsers import UploadFile, _user_safe_decode from starlette.requests import Request @@ -36,7 +37,7 @@ async def app(scope, receive, send): async def multi_items_app(scope, receive, send): request = Request(scope, receive) data = await request.form() - output = {} + output: typing.Dict[str, list] = {} for key, value in data.multi_items(): if key not in output: output[key] = [] diff --git a/tests/test_routing.py b/tests/test_routing.py index dcb996531..231c581fb 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -459,13 +459,13 @@ async def subdomain_app(scope, receive, send): await response(scope, receive, send) -subdomain_app = Router( +subdomain_router = Router( routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")] ) def test_subdomain_routing(test_client_factory): - client = test_client_factory(subdomain_app, base_url="https://foo.example.org/") + client = test_client_factory(subdomain_router, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -474,7 +474,7 @@ def test_subdomain_routing(test_client_factory): def test_subdomain_reverse_urls(): assert ( - subdomain_app.url_path_for( + subdomain_router.url_path_for( "subdomains", subdomain="foo", path="/homepage" ).make_absolute_url("https://whatever") == "https://foo.example.org/homepage" @@ -637,6 +637,7 @@ def run_startup(): raise RuntimeError() router = Router(on_startup=[run_startup]) + startup_failed = False async def app(scope, receive, send): async def _send(message): @@ -647,7 +648,6 @@ async def _send(message): await router(scope, receive, _send) - startup_failed = False with pytest.raises(RuntimeError): with test_client_factory(app): pass # pragma: nocover diff --git a/tests/test_templates.py b/tests/test_templates.py index 073482d65..aa8279348 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -28,4 +28,4 @@ async def homepage(request): def test_template_response_requires_request(tmpdir): templates = Jinja2Templates(str(tmpdir)) with pytest.raises(ValueError): - templates.TemplateResponse(None, {}) + templates.TemplateResponse("", {}) From a7c5a41396752c39a5a9b688e2dccfaca152a62f Mon Sep 17 00:00:00 2001 From: Alex Oleshkevich Date: Wed, 12 Jan 2022 12:57:47 +0300 Subject: [PATCH 4/8] Allow Session scoped cookies. (#1387) * Allow Session scoped cookies. * Update docs/middleware.md Co-authored-by: Marcelo Trylesinski * Improve typing. Co-authored-by: Marcelo Trylesinski --- docs/middleware.md | 2 +- starlette/middleware/sessions.py | 14 +++++++------- tests/middleware/test_session.py | 17 +++++++++++++++++ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/docs/middleware.md b/docs/middleware.md index 6d4d1a611..5fe7ce516 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -98,7 +98,7 @@ The following arguments are supported: * `secret_key` - Should be a random string. * `session_cookie` - Defaults to "session". -* `max_age` - Session expiry time in seconds. Defaults to 2 weeks. +* `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session. * `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`. * `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`. diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index ad7a6ee89..3ff1e3de1 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -16,7 +16,7 @@ def __init__( app: ASGIApp, secret_key: typing.Union[str, Secret], session_cookie: str = "session", - max_age: int = 14 * 24 * 60 * 60, # 14 days, in seconds + max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds same_site: str = "lax", https_only: bool = False, ) -> None: @@ -55,12 +55,12 @@ async def send_wrapper(message: Message) -> None: data = b64encode(json.dumps(scope["session"]).encode("utf-8")) data = self.signer.sign(data) headers = MutableHeaders(scope=message) - header_value = "%s=%s; path=%s; Max-Age=%d; %s" % ( - self.session_cookie, - data.decode("utf-8"), - path, - self.max_age, - self.security_flags, + header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501 + session_cookie=self.session_cookie, + data=data.decode("utf-8"), + path=path, + max_age=f"Max-Age={self.max_age}; " if self.max_age else "", + security_flags=self.security_flags, ) headers.append("Set-Cookie", header_value) elif not initial_session_was_empty: diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 07296bcbb..867a96735 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -129,3 +129,20 @@ def test_invalid_session_cookie(test_client_factory): # we expect it to not raise an exception if we provide a bogus session cookie response = client.get("/view_session", cookies={"session": "invalid"}) assert response.json() == {"session": {}} + + +def test_session_cookie(test_client_factory): + app = create_app() + app.add_middleware(SessionMiddleware, secret_key="example", max_age=None) + client = test_client_factory(app) + + response = client.post("/update_session", json={"some": "data"}) + assert response.json() == {"session": {"some": "data"}} + + # check cookie max-age + set_cookie = response.headers["set-cookie"] + assert "Max-Age" not in set_cookie + + client.cookies.clear_session_cookies() + response = client.get("/view_session") + assert response.json() == {"session": {}} From 49133f1cc3d2bbe219ee833f4d433e35003d196c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sondre=20Lilleb=C3=B8=20Gundersen?= Date: Thu, 13 Jan 2022 11:29:17 +0100 Subject: [PATCH 5/8] Correct spelling of asynchronous in events.md (#1408) --- docs/events.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/events.md b/docs/events.md index c7ed49e9d..f1514f96e 100644 --- a/docs/events.md +++ b/docs/events.md @@ -60,7 +60,7 @@ app = Starlette(routes=routes, lifespan=lifespan) ``` Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html) -for managing asynchronious tasks. +for managing asynchronous tasks. ## Running event handlers in tests From aa20cdad0d07c529e97516ba9a071762a5c17221 Mon Sep 17 00:00:00 2001 From: Mng <50384638+Mng-dev-ai@users.noreply.github.com> Date: Thu, 13 Jan 2022 23:59:47 +0200 Subject: [PATCH 6/8] Remove count of available convertors (#1409) Co-authored-by: Micheal Gendy --- docs/routing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/routing.md b/docs/routing.md index f0b672d8a..8bdc6215b 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -39,7 +39,7 @@ Route('/users/{username}', user) ``` By default this will capture characters up to the end of the path or the next `/`. -You can use convertors to modify what is captured. Four convertors are available: +You can use convertors to modify what is captured. The available convertors are: * `str` returns a string, and is the default. * `int` returns a Python integer. From 7d79ad96d5aaee71f16ac9f4e41072e81d18ab86 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 14 Jan 2022 01:40:18 -0800 Subject: [PATCH 7/8] Fix md5_hexdigest wrapper on FIPS enabled systems (#1410) * Fix md5_hexdigest wrapper on FIPS enabled systems * Update _compat.py * lint --- starlette/_compat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/starlette/_compat.py b/starlette/_compat.py index 82aa72f38..116561917 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -11,7 +11,10 @@ # See issue: https://github.com/encode/starlette/issues/1365 try: - hashlib.md5(b"data", usedforsecurity=True) # type: ignore[call-arg] + # check if the Python version supports the parameter + # using usedforsecurity=False to avoid an exception on FIPS systems + # that reject usedforsecurity=True + hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg] def md5_hexdigest( data: bytes, *, usedforsecurity: bool = True From fcc4c705ff69182ebd663bc686cb55c242d32683 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 14 Jan 2022 14:30:46 +0100 Subject: [PATCH 8/8] Use typing `NoReturn` (#1412) --- requirements.txt | 2 +- starlette/datastructures.py | 4 ++-- starlette/middleware/gzip.py | 3 ++- starlette/requests.py | 6 +++--- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 48f3cfa3e..51389304e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ coverage==6.2 databases[sqlite]==0.5.3 flake8==4.0.1 isort==5.10.1 -mypy==0.930 +mypy==0.931 types-requests==2.26.3 types-contextvars==2.4.0 types-PyYAML==6.0.1 diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 52904d51e..2c5c4b016 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -163,7 +163,7 @@ class URLPath(str): def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath": assert protocol in ("http", "websocket", "") - return str.__new__(cls, path) # type: ignore + return str.__new__(cls, path) def __init__(self, path: str, protocol: str = "", host: str = "") -> None: self.protocol = protocol @@ -441,7 +441,7 @@ def _in_memory(self) -> bool: async def write(self, data: bytes) -> None: if self._in_memory: - self.file.write(data) # type: ignore + self.file.write(data) else: await run_in_threadpool(self.file.write, data) diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index 37c6936fa..9d69ee7ca 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -1,5 +1,6 @@ import gzip import io +import typing from starlette.datastructures import Headers, MutableHeaders from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -100,5 +101,5 @@ async def send_with_gzip(self, message: Message) -> None: await self.send(message) -async def unattached_send(message: Message) -> None: +async def unattached_send(message: Message) -> typing.NoReturn: raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/starlette/requests.py b/starlette/requests.py index cf129702e..a33367e1d 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -51,7 +51,7 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: key, val = key.strip(), val.strip() if key or val: # unquote using Python's algorithm. - cookie_dict[key] = http_cookies._unquote(val) # type: ignore + cookie_dict[key] = http_cookies._unquote(val) return cookie_dict @@ -175,11 +175,11 @@ def url_for(self, name: str, **path_params: typing.Any) -> str: return url_path.make_absolute_url(base_url=self.base_url) -async def empty_receive() -> Message: +async def empty_receive() -> typing.NoReturn: raise RuntimeError("Receive channel has not been made available") -async def empty_send(message: Message) -> None: +async def empty_send(message: Message) -> typing.NoReturn: raise RuntimeError("Send channel has not been made available")