Skip to content

Commit

Permalink
re-implement simple CORS middleware for blackd
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol committed Sep 18, 2021
1 parent 911470a commit 1a9bb52
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 15 deletions.
21 changes: 6 additions & 15 deletions src/blackd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

try:
from aiohttp import web
import aiohttp_cors
from .middlewares import cors
except ImportError as ie:
raise ImportError(
f"aiohttp dependency is not installed: {ie}. "
Expand Down Expand Up @@ -67,20 +67,11 @@ def main(bind_host: str, bind_port: int) -> None:


def make_app() -> web.Application:
app = web.Application()
executor = ProcessPoolExecutor()

cors = aiohttp_cors.setup(app)
resource = cors.add(app.router.add_resource("/"))
cors.add(
resource.add_route("POST", partial(handle, executor=executor)),
{
"*": aiohttp_cors.ResourceOptions(
allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
)
},
app = web.Application(
middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
)

executor = ProcessPoolExecutor()
app.add_routes([web.post("/", partial(handle, executor=executor))])
return app


Expand Down Expand Up @@ -195,7 +186,7 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi
raise InvalidVariantHeader(f"3.{minor} is not supported")
versions.add(black.TargetVersion[version_str])
except (KeyError, ValueError):
raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
return False, versions


Expand Down
34 changes: 34 additions & 0 deletions src/blackd/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Iterable, Awaitable, Callable
from aiohttp.web import middleware
from aiohttp.web_response import StreamResponse
from aiohttp.web_request import Request

Handler = Callable[[Request], Awaitable[StreamResponse]]
Middleware = Callable[[Request, Handler], Awaitable[StreamResponse]]


def cors(allow_headers: Iterable[str]) -> Middleware:
@middleware
async def impl(request: Request, handler: Handler) -> StreamResponse:
is_options = request.method == "OPTIONS"
is_preflight = is_options and "Access-Control-Request-Method" in request.headers
if is_preflight:
resp = StreamResponse()
else:
resp = await handler(request)

origin = request.headers.get("Origin")
if not origin:
return resp

resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Expose-Headers"] = "*"
if is_options:
resp.headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
resp.headers["Access-Control-Allow-Methods"] = ", ".join(
("OPTIONS", "POST")
)

return resp

return impl # type: ignore
21 changes: 21 additions & 0 deletions tests/test_blackd.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,24 @@ async def test_blackd_invalid_line_length(self) -> None:
async def test_blackd_response_black_version_header(self) -> None:
response = await self.client.post("/")
self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))

@unittest_run_loop
async def test_cors_preflight(self) -> None:
response = await self.client.options(
"/",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "*",
"Access-Control-Request-Headers": "Content-Type",
},
)
self.assertEqual(response.status, 200)
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Headers"))
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Methods"))

@unittest_run_loop
async def test_cors_headers_present(self) -> None:
response = await self.client.post("/", headers={"Origin": "*"})
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
self.assertIsNotNone(response.headers.get("Access-Control-Expose-Headers"))

0 comments on commit 1a9bb52

Please sign in to comment.