Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-implement simple CORS middleware for blackd #2500

Merged
merged 7 commits into from Sep 25, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.md
@@ -1,5 +1,11 @@
# Change Log

## Unreleased

### _Blackd_

- Remove dependency on aiohttp-cors (#2500)

## 21.9b0

### Packaging
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -86,7 +86,7 @@ def get_long_description() -> str:
"mypy_extensions>=0.4.3",
],
extras_require={
"d": ["aiohttp>=3.6.0", "aiohttp-cors>=0.4.0"],
"d": ["aiohttp>=3.6.0"],
"colorama": ["colorama>=0.4.3"],
"python2": ["typed-ast>=1.4.2"],
"uvloop": ["uvloop>=0.15.2"],
Expand Down
19 changes: 5 additions & 14 deletions src/blackd/__init__.py
Expand Up @@ -8,7 +8,7 @@

try:
from aiohttp import web
import aiohttp_cors
from .middlewares import cors
except ImportError as ie:
raise ImportError(
f"aiohttp dependency is not installed: {ie}. "
Expand Down Expand Up @@ -67,20 +67,11 @@ def main(bind_host: str, bind_port: int) -> None:


def make_app() -> web.Application:
app = web.Application()
executor = ProcessPoolExecutor()

cors = aiohttp_cors.setup(app)
resource = cors.add(app.router.add_resource("/"))
cors.add(
resource.add_route("POST", partial(handle, executor=executor)),
{
"*": aiohttp_cors.ResourceOptions(
allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
)
},
app = web.Application(
middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
)

executor = ProcessPoolExecutor()
app.add_routes([web.post("/", partial(handle, executor=executor))])
return app


Expand Down
32 changes: 32 additions & 0 deletions src/blackd/middlewares.py
@@ -0,0 +1,32 @@
from typing import Iterable, Awaitable, Callable
from aiohttp.web_response import StreamResponse
from aiohttp.web_request import Request

Handler = Callable[[Request], Awaitable[StreamResponse]]
Middleware = Callable[[Request, Handler], Awaitable[StreamResponse]]


def cors(allow_headers: Iterable[str]) -> Middleware:
async def impl(request: Request, handler: Handler) -> StreamResponse:
is_options = request.method == "OPTIONS"
is_preflight = is_options and "Access-Control-Request-Method" in request.headers
if is_preflight:
resp = StreamResponse()
else:
resp = await handler(request)

origin = request.headers.get("Origin")
if not origin:
return resp

resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Expose-Headers"] = "*"
if is_options:
resp.headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
resp.headers["Access-Control-Allow-Methods"] = ", ".join(
("OPTIONS", "POST")
)

return resp

return impl
21 changes: 21 additions & 0 deletions tests/test_blackd.py
Expand Up @@ -164,3 +164,24 @@ async def test_blackd_invalid_line_length(self) -> None:
async def test_blackd_response_black_version_header(self) -> None:
response = await self.client.post("/")
self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))

@unittest_run_loop
async def test_cors_preflight(self) -> None:
response = await self.client.options(
"/",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "*",
"Access-Control-Request-Headers": "Content-Type",
},
)
self.assertEqual(response.status, 200)
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Headers"))
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Methods"))

@unittest_run_loop
async def test_cors_headers_present(self) -> None:
response = await self.client.post("/", headers={"Origin": "*"})
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
self.assertIsNotNone(response.headers.get("Access-Control-Expose-Headers"))