Skip to content

Commit

Permalink
Merge branch 'master' into fix-staticfiles-follow-symlinks
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee committed Jan 16, 2022
2 parents 5056b84 + fcc4c70 commit 5335c98
Show file tree
Hide file tree
Showing 20 changed files with 91 additions and 42 deletions.
2 changes: 1 addition & 1 deletion docs/events.md
Expand Up @@ -60,7 +60,7 @@ app = Starlette(routes=routes, lifespan=lifespan)
```

Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html)
for managing asynchronious tasks.
for managing asynchronous tasks.

## Running event handlers in tests

Expand Down
2 changes: 1 addition & 1 deletion docs/middleware.md
Expand Up @@ -98,7 +98,7 @@ The following arguments are supported:

* `secret_key` - Should be a random string.
* `session_cookie` - Defaults to "session".
* `max_age` - Session expiry time in seconds. Defaults to 2 weeks.
* `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session.
* `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`.
* `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`.

Expand Down
4 changes: 2 additions & 2 deletions docs/requests.md
Expand Up @@ -126,8 +126,8 @@ multidict, containing both file uploads and text input. File upload items are re

`UploadFile` has the following `async` methods. They all call the corresponding file methods underneath (using the internal `SpooledTemporaryFile`).

* `async write(data)`: Writes `data` (`str` or `bytes`) to the file.
* `async read(size)`: Reads `size` (`int`) bytes/characters of the file.
* `async write(data)`: Writes `data` (`bytes`) to the file.
* `async read(size)`: Reads `size` (`int`) bytes of the file.
* `async seek(offset)`: Goes to the byte position `offset` (`int`) in the file.
* E.g., `await myfile.seek(0)` would go to the start of the file.
* `async close()`: Closes the file.
Expand Down
2 changes: 1 addition & 1 deletion docs/routing.md
Expand Up @@ -39,7 +39,7 @@ Route('/users/{username}', user)
```
By default this will capture characters up to the end of the path or the next `/`.

You can use convertors to modify what is captured. Four convertors are available:
You can use convertors to modify what is captured. The available convertors are:

* `str` returns a string, and is the default.
* `int` returns a Python integer.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -8,7 +8,7 @@ coverage==6.2
databases[sqlite]==0.5.3
flake8==4.0.1
isort==5.10.1
mypy==0.930
mypy==0.931
types-requests==2.26.3
types-contextvars==2.4.0
types-PyYAML==6.0.1
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Expand Up @@ -9,8 +9,7 @@ show_error_codes = True

[mypy-tests.*]
disallow_untyped_defs = False
# https://github.com/encode/starlette/issues/1045
# check_untyped_defs = True
check_untyped_defs = True

[tool:isort]
profile = black
Expand Down
5 changes: 4 additions & 1 deletion starlette/_compat.py
Expand Up @@ -11,7 +11,10 @@
# See issue: https://github.com/encode/starlette/issues/1365
try:

hashlib.md5(b"data", usedforsecurity=True) # type: ignore[call-arg]
# check if the Python version supports the parameter
# using usedforsecurity=False to avoid an exception on FIPS systems
# that reject usedforsecurity=True
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]

def md5_hexdigest(
data: bytes, *, usedforsecurity: bool = True
Expand Down
16 changes: 9 additions & 7 deletions starlette/datastructures.py
Expand Up @@ -163,7 +163,7 @@ class URLPath(str):

def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
assert protocol in ("http", "websocket", "")
return str.__new__(cls, path) # type: ignore
return str.__new__(cls, path)

def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
self.protocol = protocol
Expand Down Expand Up @@ -415,35 +415,37 @@ class UploadFile:
"""

spool_max_size = 1024 * 1024
file: typing.BinaryIO
headers: "Headers"

def __init__(
self,
filename: str,
file: typing.IO = None,
file: typing.Optional[typing.BinaryIO] = None,
content_type: str = "",
*,
headers: "typing.Optional[Headers]" = None,
) -> None:
self.filename = filename
self.content_type = content_type
if file is None:
file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
self.file = file
self.file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size) # type: ignore # noqa: E501
else:
self.file = file
self.headers = headers or Headers()

@property
def _in_memory(self) -> bool:
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk

async def write(self, data: typing.Union[bytes, str]) -> None:
async def write(self, data: bytes) -> None:
if self._in_memory:
self.file.write(data) # type: ignore
self.file.write(data)
else:
await run_in_threadpool(self.file.write, data)

async def read(self, size: int = -1) -> typing.Union[bytes, str]:
async def read(self, size: int = -1) -> bytes:
if self._in_memory:
return self.file.read(size)
return await run_in_threadpool(self.file.read, size)
Expand Down
3 changes: 2 additions & 1 deletion starlette/middleware/gzip.py
@@ -1,5 +1,6 @@
import gzip
import io
import typing

from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -100,5 +101,5 @@ async def send_with_gzip(self, message: Message) -> None:
await self.send(message)


async def unattached_send(message: Message) -> None:
async def unattached_send(message: Message) -> typing.NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover
14 changes: 7 additions & 7 deletions starlette/middleware/sessions.py
Expand Up @@ -16,7 +16,7 @@ def __init__(
app: ASGIApp,
secret_key: typing.Union[str, Secret],
session_cookie: str = "session",
max_age: int = 14 * 24 * 60 * 60, # 14 days, in seconds
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
same_site: str = "lax",
https_only: bool = False,
) -> None:
Expand Down Expand Up @@ -55,12 +55,12 @@ async def send_wrapper(message: Message) -> None:
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "%s=%s; path=%s; Max-Age=%d; %s" % (
self.session_cookie,
data.decode("utf-8"),
path,
self.max_age,
self.security_flags,
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=path,
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif not initial_session_was_empty:
Expand Down
6 changes: 3 additions & 3 deletions starlette/requests.py
Expand Up @@ -51,7 +51,7 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
key, val = key.strip(), val.strip()
if key or val:
# unquote using Python's algorithm.
cookie_dict[key] = http_cookies._unquote(val) # type: ignore
cookie_dict[key] = http_cookies._unquote(val)
return cookie_dict


Expand Down Expand Up @@ -175,11 +175,11 @@ def url_for(self, name: str, **path_params: typing.Any) -> str:
return url_path.make_absolute_url(base_url=self.base_url)


async def empty_receive() -> Message:
async def empty_receive() -> typing.NoReturn:
raise RuntimeError("Receive channel has not been made available")


async def empty_send(message: Message) -> None:
async def empty_send(message: Message) -> typing.NoReturn:
raise RuntimeError("Send channel has not been made available")


Expand Down
6 changes: 5 additions & 1 deletion starlette/responses.py
Expand Up @@ -71,7 +71,11 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None:
populate_content_type = b"content-type" not in keys

body = getattr(self, "body", None)
if body is not None and populate_content_length:
if (
body is not None
and populate_content_length
and not (self.status_code < 200 or self.status_code in (204, 304))
):
content_length = str(len(body))
raw_headers.append((b"content-length", content_length.encode("latin-1")))

Expand Down
2 changes: 1 addition & 1 deletion tests/middleware/test_cors.py
Expand Up @@ -258,7 +258,7 @@ def test_cors_allow_all_methods(test_client_factory):
)

@app.route(
"/", methods=("delete", "get", "head", "options", "patch", "post", "put")
"/", methods=["delete", "get", "head", "options", "patch", "post", "put"]
)
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
Expand Down
25 changes: 23 additions & 2 deletions tests/middleware/test_session.py
Expand Up @@ -66,7 +66,9 @@ def test_session_expires(test_client_factory):
# requests removes expired cookies from response.cookies, we need to
# fetch session id from the headers and pass it explicitly
expired_cookie_header = response.headers["set-cookie"]
expired_session_value = re.search(r"session=([^;]*);", expired_cookie_header)[1]
expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header)
assert expired_session_match is not None
expired_session_value = expired_session_match[1]
response = client.get("/view_session", cookies={"session": expired_session_value})
assert response.json() == {"session": {}}

Expand Down Expand Up @@ -110,7 +112,9 @@ def test_session_cookie_subpath(test_client_factory):
client = test_client_factory(app, base_url="http://testserver/second_app")
response = client.post("second_app/update_session", json={"some": "data"})
cookie = response.headers["set-cookie"]
cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0]
cookie_path_match = re.search(r"; path=(\S+);", cookie)
assert cookie_path_match is not None
cookie_path = cookie_path_match.groups()[0]
assert cookie_path == "/second_app"


Expand All @@ -125,3 +129,20 @@ def test_invalid_session_cookie(test_client_factory):
# we expect it to not raise an exception if we provide a bogus session cookie
response = client.get("/view_session", cookies={"session": "invalid"})
assert response.json() == {"session": {}}


def test_session_cookie(test_client_factory):
app = create_app()
app.add_middleware(SessionMiddleware, secret_key="example", max_age=None)
client = test_client_factory(app)

response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}

# check cookie max-age
set_cookie = response.headers["set-cookie"]
assert "Max-Age" not in set_cookie

client.cookies.clear_session_cookies()
response = client.get("/view_session")
assert response.json() == {"session": {}}
2 changes: 2 additions & 0 deletions tests/test_database.py
Expand Up @@ -77,6 +77,7 @@ async def read_note(request):
note_id = request.path_params["note_id"]
query = notes.select().where(notes.c.id == note_id)
result = await database.fetch_one(query)
assert result is not None
content = {"text": result["text"], "completed": result["completed"]}
return JSONResponse(content)

Expand All @@ -86,6 +87,7 @@ async def read_note_text(request):
note_id = request.path_params["note_id"]
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
result = await database.fetch_one(query)
assert result is not None
return JSONResponse(result[0])


Expand Down
19 changes: 14 additions & 5 deletions tests/test_datastructures.py
Expand Up @@ -227,8 +227,21 @@ async def test_upload_file():
await big_file.close()


@pytest.mark.anyio
async def test_upload_file_file_input():
"""Test passing file/stream into the UploadFile constructor"""
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream)
assert await file.read() == b"data"
await file.write(b" and more data!")
assert await file.read() == b""
await file.seek(0)
assert await file.read() == b"data and more data!"


def test_formdata():
upload = io.BytesIO(b"test")
stream = io.BytesIO(b"data")
upload = UploadFile(filename="file", file=stream)
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
assert "a" in form
assert "A" not in form
Expand Down Expand Up @@ -338,10 +351,6 @@ def test_multidict():
q.update(q)
assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])"

q = MultiDict([("a", "123"), ("b", "456")])
q.update(None)
assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])"

q = MultiDict([("a", "123"), ("a", "456")])
q.update([("a", "123")])
assert q.getlist("a") == ["123"]
Expand Down
3 changes: 2 additions & 1 deletion tests/test_formparsers.py
@@ -1,4 +1,5 @@
import os
import typing

from starlette.formparsers import UploadFile, _user_safe_decode
from starlette.requests import Request
Expand Down Expand Up @@ -36,7 +37,7 @@ async def app(scope, receive, send):
async def multi_items_app(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, list] = {}
for key, value in data.multi_items():
if key not in output:
output[key] = []
Expand Down
7 changes: 7 additions & 0 deletions tests/test_responses.py
Expand Up @@ -333,6 +333,13 @@ def test_empty_response(test_client_factory):
assert response.headers["content-length"] == "0"


def test_empty_204_response(test_client_factory):
app = Response(status_code=204)
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers


def test_non_empty_response(test_client_factory):
app = Response(content="hi")
client: TestClient = test_client_factory(app)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_routing.py
Expand Up @@ -459,13 +459,13 @@ async def subdomain_app(scope, receive, send):
await response(scope, receive, send)


subdomain_app = Router(
subdomain_router = Router(
routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
)


def test_subdomain_routing(test_client_factory):
client = test_client_factory(subdomain_app, base_url="https://foo.example.org/")
client = test_client_factory(subdomain_router, base_url="https://foo.example.org/")

response = client.get("/")
assert response.status_code == 200
Expand All @@ -474,7 +474,7 @@ def test_subdomain_routing(test_client_factory):

def test_subdomain_reverse_urls():
assert (
subdomain_app.url_path_for(
subdomain_router.url_path_for(
"subdomains", subdomain="foo", path="/homepage"
).make_absolute_url("https://whatever")
== "https://foo.example.org/homepage"
Expand Down Expand Up @@ -637,6 +637,7 @@ def run_startup():
raise RuntimeError()

router = Router(on_startup=[run_startup])
startup_failed = False

async def app(scope, receive, send):
async def _send(message):
Expand All @@ -647,7 +648,6 @@ async def _send(message):

await router(scope, receive, _send)

startup_failed = False
with pytest.raises(RuntimeError):
with test_client_factory(app):
pass # pragma: nocover
Expand Down
2 changes: 1 addition & 1 deletion tests/test_templates.py
Expand Up @@ -28,4 +28,4 @@ async def homepage(request):
def test_template_response_requires_request(tmpdir):
templates = Jinja2Templates(str(tmpdir))
with pytest.raises(ValueError):
templates.TemplateResponse(None, {})
templates.TemplateResponse("", {})

0 comments on commit 5335c98

Please sign in to comment.