Skip to content

Commit

Permalink
re-implement simple CORS middleware for blackd (#2500)
Browse files Browse the repository at this point in the history
* re-implement simple CORS middleware for blackd
* remove aiohttp-cors from setup.py
* Remove aiohttp-cors from Pipfile.lock

Co-authored-by: Richard Si <63936253+ichard26@users.noreply.github.com>
  • Loading branch information
zsol and ichard26 committed Sep 25, 2021
1 parent 7b15393 commit a5381ba
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 34 deletions.
6 changes: 6 additions & 0 deletions CHANGES.md
@@ -1,5 +1,11 @@
# Change Log

## Unreleased

### _Blackd_

- Remove dependency on aiohttp-cors (#2500)

## 21.9b0

### Packaging
Expand Down
1 change: 0 additions & 1 deletion Pipfile
Expand Up @@ -41,7 +41,6 @@ black = {editable = true, extras = ["d", "jupyter"], path = "."}

[packages]
aiohttp = ">=3.6.0"
aiohttp-cors = ">=0.4.0"
platformdirs= ">=2"
click = ">=8.0.0"
mypy_extensions = ">=0.4.3"
Expand Down
42 changes: 24 additions & 18 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -86,7 +86,7 @@ def get_long_description() -> str:
"mypy_extensions>=0.4.3",
],
extras_require={
"d": ["aiohttp>=3.6.0", "aiohttp-cors>=0.4.0"],
"d": ["aiohttp>=3.6.0"],
"colorama": ["colorama>=0.4.3"],
"python2": ["typed-ast>=1.4.2"],
"uvloop": ["uvloop>=0.15.2"],
Expand Down
19 changes: 5 additions & 14 deletions src/blackd/__init__.py
Expand Up @@ -8,7 +8,7 @@

try:
from aiohttp import web
import aiohttp_cors
from .middlewares import cors
except ImportError as ie:
raise ImportError(
f"aiohttp dependency is not installed: {ie}. "
Expand Down Expand Up @@ -67,20 +67,11 @@ def main(bind_host: str, bind_port: int) -> None:


def make_app() -> web.Application:
app = web.Application()
executor = ProcessPoolExecutor()

cors = aiohttp_cors.setup(app)
resource = cors.add(app.router.add_resource("/"))
cors.add(
resource.add_route("POST", partial(handle, executor=executor)),
{
"*": aiohttp_cors.ResourceOptions(
allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
)
},
app = web.Application(
middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
)

executor = ProcessPoolExecutor()
app.add_routes([web.post("/", partial(handle, executor=executor))])
return app


Expand Down
34 changes: 34 additions & 0 deletions src/blackd/middlewares.py
@@ -0,0 +1,34 @@
from typing import Iterable, Awaitable, Callable
from aiohttp.web_response import StreamResponse
from aiohttp.web_request import Request
from aiohttp.web_middlewares import middleware

Handler = Callable[[Request], Awaitable[StreamResponse]]
Middleware = Callable[[Request, Handler], Awaitable[StreamResponse]]


def cors(allow_headers: Iterable[str]) -> Middleware:
@middleware
async def impl(request: Request, handler: Handler) -> StreamResponse:
is_options = request.method == "OPTIONS"
is_preflight = is_options and "Access-Control-Request-Method" in request.headers
if is_preflight:
resp = StreamResponse()
else:
resp = await handler(request)

origin = request.headers.get("Origin")
if not origin:
return resp

resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Expose-Headers"] = "*"
if is_options:
resp.headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
resp.headers["Access-Control-Allow-Methods"] = ", ".join(
("OPTIONS", "POST")
)

return resp

return impl # type: ignore
21 changes: 21 additions & 0 deletions tests/test_blackd.py
Expand Up @@ -164,3 +164,24 @@ async def test_blackd_invalid_line_length(self) -> None:
async def test_blackd_response_black_version_header(self) -> None:
response = await self.client.post("/")
self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))

@unittest_run_loop
async def test_cors_preflight(self) -> None:
response = await self.client.options(
"/",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "*",
"Access-Control-Request-Headers": "Content-Type",
},
)
self.assertEqual(response.status, 200)
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Headers"))
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Methods"))

@unittest_run_loop
async def test_cors_headers_present(self) -> None:
response = await self.client.post("/", headers={"Origin": "*"})
self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
self.assertIsNotNone(response.headers.get("Access-Control-Expose-Headers"))

0 comments on commit a5381ba

Please sign in to comment.