Skip to content

Commit

Permalink
Add Mypy checks to tests
Browse files Browse the repository at this point in the history
Add support for functools.partial in WebsocketRoute (#1356)

* Add support for functools.partial in WebsocketRoute

* remove commented code

* Refactor tests for partian endpoint and ws

Fix snippet on config.md (#1358)

* Fix snippet on config.md

* Update docs/config.md

Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>

Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>

remove outdated test
  • Loading branch information
aminalaee committed Dec 13, 2021
1 parent f53faba commit bcddbc0
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 51 deletions.
5 changes: 3 additions & 2 deletions docs/config.md
Expand Up @@ -128,16 +128,17 @@ application logic separated:
**myproject/settings.py**:

```python
import databases
from starlette.config import Config
from starlette.datastructures import URL, Secret
from starlette.datastructures import Secret

config = Config(".env")

DEBUG = config('DEBUG', cast=bool, default=False)
TESTING = config('TESTING', cast=bool, default=False)
SECRET_KEY = config('SECRET_KEY', cast=Secret)

DATABASE_URL = config('DATABASE_URL', cast=URL)
DATABASE_URL = config('DATABASE_URL', cast=databases.DatabaseURL)
if TESTING:
DATABASE_URL = DATABASE_URL.replace(database='test_' + DATABASE_URL.database)
```
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Expand Up @@ -8,8 +8,7 @@ ignore_missing_imports = True

[mypy-tests.*]
disallow_untyped_defs = False
# https://github.com/encode/starlette/issues/1045
# check_untyped_defs = True
check_untyped_defs = True

[tool:isort]
profile = black
Expand Down
5 changes: 4 additions & 1 deletion starlette/routing.py
Expand Up @@ -276,7 +276,10 @@ def __init__(
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name

if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(websocket)`.
self.app = websocket_session(endpoint)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/middleware/test_cors.py
Expand Up @@ -258,7 +258,7 @@ def test_cors_allow_all_methods(test_client_factory):
)

@app.route(
"/", methods=("delete", "get", "head", "options", "patch", "post", "put")
"/", methods=["delete", "get", "head", "options", "patch", "post", "put"]
)
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
Expand Down
8 changes: 6 additions & 2 deletions tests/middleware/test_session.py
Expand Up @@ -66,7 +66,9 @@ def test_session_expires(test_client_factory):
# requests removes expired cookies from response.cookies, we need to
# fetch session id from the headers and pass it explicitly
expired_cookie_header = response.headers["set-cookie"]
expired_session_value = re.search(r"session=([^;]*);", expired_cookie_header)[1]
expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header)
assert expired_session_match is not None
expired_session_value = expired_session_match[1]
response = client.get("/view_session", cookies={"session": expired_session_value})
assert response.json() == {"session": {}}

Expand Down Expand Up @@ -110,7 +112,9 @@ def test_session_cookie_subpath(test_client_factory):
client = test_client_factory(app, base_url="http://testserver/second_app")
response = client.post("second_app/update_session", json={"some": "data"})
cookie = response.headers["set-cookie"]
cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0]
cookie_path_match = re.search(r"; path=(\S+);", cookie)
assert cookie_path_match is not None
cookie_path = cookie_path_match.groups()[0]
assert cookie_path == "/second_app"


Expand Down
2 changes: 2 additions & 0 deletions tests/test_database.py
Expand Up @@ -77,6 +77,7 @@ async def read_note(request):
note_id = request.path_params["note_id"]
query = notes.select().where(notes.c.id == note_id)
result = await database.fetch_one(query)
assert result is not None
content = {"text": result["text"], "completed": result["completed"]}
return JSONResponse(content)

Expand All @@ -86,6 +87,7 @@ async def read_note_text(request):
note_id = request.path_params["note_id"]
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
result = await database.fetch_one(query)
assert result is not None
return JSONResponse(result[0])


Expand Down
8 changes: 1 addition & 7 deletions tests/test_datastructures.py
@@ -1,5 +1,3 @@
import io

import pytest

from starlette.datastructures import (
Expand Down Expand Up @@ -228,7 +226,7 @@ async def test_upload_file():


def test_formdata():
upload = io.BytesIO(b"test")
upload = UploadFile("test")
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
assert "a" in form
assert "A" not in form
Expand Down Expand Up @@ -338,10 +336,6 @@ def test_multidict():
q.update(q)
assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])"

q = MultiDict([("a", "123"), ("b", "456")])
q.update(None)
assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])"

q = MultiDict([("a", "123"), ("a", "456")])
q.update([("a", "123")])
assert q.getlist("a") == ["123"]
Expand Down
9 changes: 6 additions & 3 deletions tests/test_formparsers.py
@@ -1,4 +1,5 @@
import os
import typing

from starlette.formparsers import UploadFile, _user_safe_decode
from starlette.requests import Request
Expand All @@ -21,9 +22,10 @@ async def app(scope, receive, send):
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
content = content.decode() if isinstance(content, bytes) else content
output[key] = {
"filename": value.filename,
"content": content.decode(),
"content": content,
"content_type": value.content_type,
}
else:
Expand All @@ -36,16 +38,17 @@ async def app(scope, receive, send):
async def multi_items_app(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
output: typing.Dict[str, list] = {}
for key, value in data.multi_items():
if key not in output:
output[key] = []
if isinstance(value, UploadFile):
content = await value.read()
content = content.decode() if isinstance(content, bytes) else content
output[key].append(
{
"filename": value.filename,
"content": content.decode(),
"content": content,
"content_type": value.content_type,
}
)
Expand Down
89 changes: 57 additions & 32 deletions tests/test_routing.py
Expand Up @@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover
return Response(content, media_type="text/plain")


async def partial_endpoint(arg, request):
return JSONResponse({"arg": arg})


async def partial_ws_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()


class PartialRoutes:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})

@classmethod
async def async_ws_endpoint(cls, websocket: WebSocket):
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()


app = Router(
[
Route("/", endpoint=homepage, methods=["GET"]),
Expand All @@ -44,6 +66,21 @@ def user_no_match(request): # pragma: no cover
Route("/nomatch", endpoint=user_no_match),
],
),
Mount(
"/partial",
routes=[
Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
Route(
"/cls",
endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
),
WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
WebSocketRoute(
"/ws/cls",
endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
),
],
),
Mount("/static", app=Response("xxxxx", media_type="image/png")),
]
)
Expand Down Expand Up @@ -91,14 +128,14 @@ def path_with_parentheses(request):


@app.websocket_route("/ws")
async def websocket_endpoint(session):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()


@app.websocket_route("/ws/{room}")
async def websocket_params(session):
async def websocket_params(session: WebSocket):
await session.accept()
await session.send_text(f"Hello, {session.path_params['room']}!")
await session.close()
Expand Down Expand Up @@ -422,13 +459,13 @@ async def subdomain_app(scope, receive, send):
await response(scope, receive, send)


subdomain_app = Router(
subdomain_router = Router(
routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
)


def test_subdomain_routing(test_client_factory):
client = test_client_factory(subdomain_app, base_url="https://foo.example.org/")
client = test_client_factory(subdomain_router, base_url="https://foo.example.org/")

response = client.get("/")
assert response.status_code == 200
Expand All @@ -437,7 +474,7 @@ def test_subdomain_routing(test_client_factory):

def test_subdomain_reverse_urls():
assert (
subdomain_app.url_path_for(
subdomain_router.url_path_for(
"subdomains", subdomain="foo", path="/homepage"
).make_absolute_url("https://whatever")
== "https://foo.example.org/homepage"
Expand Down Expand Up @@ -600,6 +637,7 @@ def run_startup():
raise RuntimeError()

router = Router(on_startup=[run_startup])
startup_failed = False

async def app(scope, receive, send):
async def _send(message):
Expand All @@ -610,7 +648,6 @@ async def _send(message):

await router(scope, receive, _send)

startup_failed = False
with pytest.raises(RuntimeError):
with test_client_factory(app):
pass # pragma: nocover
Expand All @@ -628,40 +665,28 @@ def run_shutdown():
pass # pragma: nocover


class AsyncEndpointClassMethod:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})


async def _partial_async_endpoint(arg, request):
return JSONResponse({"arg": arg})


partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
partial_cls_async_endpoint = functools.partial(
AsyncEndpointClassMethod.async_endpoint, "foo"
)

partial_async_app = Router(
routes=[
Route("/", partial_async_endpoint),
Route("/cls", partial_cls_async_endpoint),
]
)


def test_partial_async_endpoint(test_client_factory):
test_client = test_client_factory(partial_async_app)
response = test_client.get("/")
test_client = test_client_factory(app)
response = test_client.get("/partial")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}

cls_method_response = test_client.get("/cls")
cls_method_response = test_client.get("/partial/cls")
assert cls_method_response.status_code == 200
assert cls_method_response.json() == {"arg": "foo"}


def test_partial_async_ws_endpoint(test_client_factory):
test_client = test_client_factory(app)
with test_client.websocket_connect("/partial/ws") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws"}

with test_client.websocket_connect("/partial/ws/cls") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws/cls"}


def test_duplicated_param_names():
with pytest.raises(
ValueError,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_templates.py
Expand Up @@ -28,4 +28,4 @@ async def homepage(request):
def test_template_response_requires_request(tmpdir):
templates = Jinja2Templates(str(tmpdir))
with pytest.raises(ValueError):
templates.TemplateResponse(None, {})
templates.TemplateResponse("", {})

0 comments on commit bcddbc0

Please sign in to comment.