Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pathlib update to TLS #2840

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 10 additions & 9 deletions sanic/http/tls/context.py
Expand Up @@ -3,6 +3,7 @@
import os
import ssl

from pathlib import Path, PurePath
from typing import Any, Dict, Iterable, Optional, Union

from sanic.log import logger
Expand Down Expand Up @@ -39,23 +40,23 @@ def create_context(


def shorthand_to_ctx(
ctxdef: Union[None, ssl.SSLContext, dict, str]
ctxdef: Union[None, ssl.SSLContext, dict, PurePath, str]
) -> Optional[ssl.SSLContext]:
"""Convert an ssl argument shorthand to an SSLContext object."""
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
return ctxdef
if isinstance(ctxdef, str):
return load_cert_dir(ctxdef)
if isinstance(ctxdef, (PurePath, str)):
return load_cert_dir(Path(ctxdef))
if isinstance(ctxdef, dict):
return CertSimple(**ctxdef)
raise ValueError(
f"Invalid ssl argument {type(ctxdef)}."
" Expecting a list of certdirs, a dict or an SSLContext."
" Expecting one/list of: certdir | dict | SSLContext"
)


def process_to_context(
ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple]
ssldef: Union[None, ssl.SSLContext, dict, PurePath, str, list, tuple]
) -> Optional[ssl.SSLContext]:
"""Process app.run ssl argument from easy formats to full SSLContext."""
return (
Expand All @@ -65,11 +66,11 @@ def process_to_context(
)


def load_cert_dir(p: str) -> ssl.SSLContext:
if os.path.isfile(p):
def load_cert_dir(p: Path) -> ssl.SSLContext:
if p.is_file():
raise ValueError(f"Certificate folder expected but {p} is a file.")
keyfile = os.path.join(p, "privkey.pem")
certfile = os.path.join(p, "fullchain.pem")
keyfile = p / "privkey.pem"
certfile = p / "fullchain.pem"
if not os.access(keyfile, os.R_OK):
raise ValueError(
f"Certificate not found or permission denied {keyfile}"
Expand Down
25 changes: 12 additions & 13 deletions tests/test_tls.py
@@ -1,5 +1,4 @@
import logging
import os
import ssl
import subprocess
import sys
Expand Down Expand Up @@ -32,18 +31,18 @@
from sanic.worker.loader import CertLoader


current_dir = os.path.dirname(os.path.realpath(__file__))
localhost_dir = os.path.join(current_dir, "certs/localhost")
password_dir = os.path.join(current_dir, "certs/password")
sanic_dir = os.path.join(current_dir, "certs/sanic.example")
invalid_dir = os.path.join(current_dir, "certs/invalid.nonexist")
localhost_cert = os.path.join(localhost_dir, "fullchain.pem")
localhost_key = os.path.join(localhost_dir, "privkey.pem")
sanic_cert = os.path.join(sanic_dir, "fullchain.pem")
sanic_key = os.path.join(sanic_dir, "privkey.pem")
current_dir = Path(__file__).parent.resolve()
localhost_dir = current_dir / "certs/localhost"
password_dir = current_dir / "certs/password"
sanic_dir = current_dir / "certs/sanic.example"
invalid_dir = current_dir / "certs/invalid.nonexist"
localhost_cert = localhost_dir / "fullchain.pem"
localhost_key = localhost_dir / "privkey.pem"
sanic_cert = sanic_dir / "fullchain.pem"
sanic_key = sanic_dir / "privkey.pem"
password_dict = {
"cert": os.path.join(password_dir, "fullchain.pem"),
"key": os.path.join(password_dir, "privkey.pem"),
"cert": password_dir / "fullchain.pem",
"key": password_dir / "privkey.pem",
"password": "password",
"names": ["localhost"],
}
Expand Down Expand Up @@ -383,7 +382,7 @@ async def handler(request):
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})

assert "folder expected" in str(excinfo.value)
assert sanic_cert in str(excinfo.value)
assert str(sanic_cert) in str(excinfo.value)


def test_missing_cert_path(app):
Expand Down