diff --git a/sanic_routing/patterns.py b/sanic_routing/patterns.py index 0ea604c..9dd816b 100644 --- a/sanic_routing/patterns.py +++ b/sanic_routing/patterns.py @@ -1,6 +1,11 @@ import re +import typing as t import uuid from datetime import date, datetime +from types import SimpleNamespace +from typing import Any, Callable, Dict, Pattern, Tuple, Type + +from sanic_routing.exceptions import InvalidUsage, NotFound def parse_date(d) -> date: @@ -19,13 +24,120 @@ def slug(param: str) -> str: return param +def ext(param: str) -> Tuple[str, ...]: + parts = tuple(param.split(".")) + if any(not p for p in parts) or len(parts) == 1: + raise ValueError(f"Value {param} does not match filename format") + return parts + + def nonemptystr(param: str) -> str: if not param: raise ValueError(f"Value {param} is an empty string") return param +class ParamInfo: + __slots__ = ( + "cast", + "ctx", + "label", + "name", + "pattern", + "priority", + "raw_path", + "regex", + ) + + def __init__( + self, + name: str, + raw_path: str, + label: str, + cast: t.Callable[[str], t.Any], + pattern: re.Pattern, + regex: bool, + priority: int, + ) -> None: + self.name = name + self.raw_path = raw_path + self.label = label + self.cast = cast + self.pattern = pattern + self.regex = regex + self.priority = priority + self.ctx = SimpleNamespace() + + def process( + self, + params: t.Dict[str, t.Any], + value: t.Union[str, t.Tuple[str, ...]], + ) -> None: + params[self.name] = value + + +class ExtParamInfo(ParamInfo): + def __init__(self, **kwargs): + super().__init__(**kwargs) + match = REGEX_PARAM_NAME_EXT.match(self.raw_path) + if not match: + raise InvalidUsage( + f"Invalid extension parameter definition: {self.raw_path}" + ) + if match.group(2) == "path": + raise InvalidUsage( + "Extension parameter matching does not support the " + "`path` type." + ) + ext_type = match.group(3) + regex_type = REGEX_TYPES.get(match.group(2)) + self.ctx.cast = None + if regex_type: + self.ctx.cast = regex_type[0] + elif match.group(2): + raise InvalidUsage( + "Extension parameter matching only supports filename matching " + "on known parameter types, and not regular expressions." + ) + self.ctx.allowed = [] + self.ctx.allowed_sub_count = 0 + if ext_type: + self.ctx.allowed = ext_type.split("|") + allowed_subs = {allowed.count(".") for allowed in self.ctx.allowed} + if len(allowed_subs) > 1: + raise InvalidUsage( + "All allowed extensions within a single route definition " + "must contain the same number of subparts. For example: " + " and are both " + "acceptable, but is not." + ) + self.ctx.allowed_sub_count = next(iter(allowed_subs)) + + for extension in self.ctx.allowed: + if not REGEX_ALLOWED_EXTENSION.match(extension): + raise InvalidUsage(f"Invalid extension: {extension}") + + def process(self, params, value): + stop = -1 * (self.ctx.allowed_sub_count + 1) + filename = ".".join(value[:stop]) + ext = ".".join(value[stop:]) + if self.ctx.allowed and ext not in self.ctx.allowed: + raise NotFound(f"Invalid extension: {ext}") + if self.ctx.cast: + try: + filename = self.ctx.cast(filename) + except ValueError: + raise NotFound(f"Invalid filename: {filename}") + params[self.name] = filename + params["ext"] = ext + + +EXTENSION = r"[a-z0-9](?:[a-z0-9\.]*[a-z0-9])?" REGEX_PARAM_NAME = re.compile(r"^<([a-zA-Z_][a-zA-Z0-9_]*)(?::(.*))?>$") +REGEX_PARAM_NAME_EXT = re.compile( + r"^<([a-zA-Z_][a-zA-Z0-9_]*)(?:=([a-z]+))?(?::ext(?:=([a-z0-9|\.]+))?)>$" +) +REGEX_ALLOWED_EXTENSION = re.compile(r"^" + EXTENSION + r"$") # Predefined path parameter types. The value is a tuple consisteing of a # callable and a compiled regular expression. @@ -35,17 +147,22 @@ def nonemptystr(param: str) -> str: # 3. raise ValueError if it cannot # The regular expression is generally NOT used. Unless the path is forced # to use regex patterns. -REGEX_TYPES = { - "strorempty": (str, re.compile(r"^[^/]*$")), - "str": (nonemptystr, re.compile(r"^[^/]+$")), - "slug": (slug, re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")), - "alpha": (alpha, re.compile(r"^[A-Za-z]+$")), - "path": (str, re.compile(r"^[^/]?.*?$")), - "float": (float, re.compile(r"^-?(?:\d+(?:\.\d*)?|\.\d+)$")), - "int": (int, re.compile(r"^-?\d+$")), +REGEX_TYPES_ANNOTATION = Dict[ + str, Tuple[Callable[[str], Any], Pattern, Type[ParamInfo]] +] +REGEX_TYPES: REGEX_TYPES_ANNOTATION = { + "strorempty": (str, re.compile(r"^[^/]*$"), ParamInfo), + "str": (nonemptystr, re.compile(r"^[^/]+$"), ParamInfo), + "ext": (ext, re.compile(r"^[^/]+\." + EXTENSION + r"$"), ExtParamInfo), + "slug": (slug, re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$"), ParamInfo), + "alpha": (alpha, re.compile(r"^[A-Za-z]+$"), ParamInfo), + "path": (str, re.compile(r"^[^/]?.*?$"), ParamInfo), + "float": (float, re.compile(r"^-?(?:\d+(?:\.\d*)?|\.\d+)$"), ParamInfo), + "int": (int, re.compile(r"^-?\d+$"), ParamInfo), "ymd": ( parse_date, re.compile(r"^([12]\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01]))$"), + ParamInfo, ), "uuid": ( uuid.UUID, @@ -53,5 +170,6 @@ def nonemptystr(param: str) -> str: r"^[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-" r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}$" ), + ParamInfo, ), } diff --git a/sanic_routing/route.py b/sanic_routing/route.py index f2ff972..355566d 100644 --- a/sanic_routing/route.py +++ b/sanic_routing/route.py @@ -1,17 +1,12 @@ import re import typing as t -from collections import namedtuple from types import SimpleNamespace from warnings import warn from .exceptions import InvalidUsage, ParameterNameConflicts +from .patterns import ParamInfo from .utils import Immutable, parts_to_path, path_to_parts -ParamInfo = namedtuple( - "ParamInfo", - ("name", "raw_path", "label", "cast", "pattern", "regex", "priority"), -) - class Requirements(Immutable): def __hash__(self): @@ -169,9 +164,17 @@ def _setup_params(self): label, _type, pattern, + param_info_class, ) = self.parse_parameter_string(part[1:-1]) + self.add_parameter( - idx, name, key_path, label, _type, pattern + idx, + name, + key_path, + label, + _type, + pattern, + param_info_class, ) def add_parameter( @@ -182,6 +185,7 @@ def add_parameter( label: str, cast: t.Type, pattern=None, + param_info_class=ParamInfo, ): if pattern and isinstance(pattern, str): if not pattern.startswith("^"): @@ -197,8 +201,14 @@ def add_parameter( if is_regex else list(self.router.regex_types.keys()).index(label) ) - self._params[idx] = ParamInfo( - name, raw_path, label, cast, pattern, is_regex, priority + self._params[idx] = param_info_class( + name=name, + raw_path=raw_path, + label=label, + cast=cast, + pattern=pattern, + regex=is_regex, + priority=priority, ) def _finalize_params(self): @@ -210,16 +220,25 @@ def _finalize_params(self): f"Duplicate named parameters in: {self._raw_path}" ) self.labels = labels + self.params = dict( sorted(params.items(), key=lambda param: self._sorting(param[1])) ) + if not self.regex and any( + ":" in param.label for param in self.params.values() + ): + raise InvalidUsage( + f"Invalid parameter declaration: {self.raw_path}" + ) + def _compile_regex(self): components = [] for part in self.parts: if part.startswith("<"): - name, *_, pattern = self.parse_parameter_string(part) + name, *_, pattern, __ = self.parse_parameter_string(part) + if not isinstance(pattern, str): pattern = pattern.pattern.strip("^$") compiled = re.compile(pattern) @@ -316,8 +335,14 @@ def parse_parameter_string(self, parameter_string: str): parameter_string = parameter_string.strip("<>") name = parameter_string label = "str" + if ":" in parameter_string: name, label = parameter_string.split(":", 1) + if "=" in label: + label, _ = label.split("=", 1) + if "=" in name: + name, _ = name.split("=", 1) + if not name: raise ValueError( f"Invalid parameter syntax: {parameter_string}" @@ -337,7 +362,9 @@ def parse_parameter_string(self, parameter_string: str): DeprecationWarning, ) - default = (str, label) + default = (str, label, ParamInfo) + # Pull from pre-configured types - _type, pattern = self.router.regex_types.get(label, default) - return name, label, _type, pattern + found = self.router.regex_types.get(label, default) + _type, pattern, param_info_class = found + return name, label, _type, pattern, param_info_class diff --git a/sanic_routing/router.py b/sanic_routing/router.py index 4f586e5..bd0b052 100644 --- a/sanic_routing/router.py +++ b/sanic_routing/router.py @@ -6,6 +6,7 @@ from warnings import warn from sanic_routing.group import RouteGroup +from sanic_routing.patterns import ParamInfo from .exceptions import ( BadMethod, @@ -15,7 +16,7 @@ NotFound, ) from .line import Line -from .patterns import REGEX_TYPES +from .patterns import REGEX_TYPES, REGEX_TYPES_ANNOTATION from .route import Route from .tree import Node, Tree from .utils import parts_to_path, path_to_parts @@ -60,7 +61,10 @@ def __init__( self.ctx = SimpleNamespace() self.cascade_not_found = cascade_not_found - self.regex_types = {**REGEX_TYPES} + self.regex_types: REGEX_TYPES_ANNOTATION = {} + + for label, (cast, pattern, param_info_class) in REGEX_TYPES.items(): + self.register_pattern(label, cast, pattern, param_info_class) @abstractmethod def get(self, **kwargs): @@ -106,21 +110,25 @@ def resolve( # Convert matched values to parameters params = param_basket["__params__"] - if route.regex: - params.update( - { - param.name: param.cast( - param_basket["__params__"][param.name] - ) - for param in route.params.values() - if param.cast is not str - } - ) - elif param_basket["__matches__"]: - params = { - param.name: param_basket["__matches__"][idx] - for idx, param in route.params.items() - } + if not params or param_basket["__matches__"]: + # If param_basket["__params__"] does not exist, we might have + # param_basket["__matches__"], which are indexed based matches + # on path segments. They should already be cast types. + for idx, param in route.params.items(): + # If the param index does not exist, then rely upon + # the __params__ + try: + value = param_basket["__matches__"][idx] + except KeyError: + continue + + # Apply if tuple (from ext) or if it is not a regex matcher + if isinstance(value, tuple): + param.process(params, value) + elif not route.regex or ( + route.regex and param.cast is not str + ): + params[param.name] = value # Double check that if we made a match it is not a false positive # because of strict_slashes @@ -248,6 +256,7 @@ def register_pattern( label: str, cast: t.Callable[[str], t.Any], pattern: t.Union[t.Pattern, str], + param_info_class: t.Type[ParamInfo] = ParamInfo, ): """ Add a custom parameter type to the router. The cast should raise a @@ -288,7 +297,7 @@ def register_pattern( pattern = re.compile(pattern) globals()[cast.__name__] = cast - self.regex_types[label] = (cast, pattern) + self.regex_types[label] = (cast, pattern, param_info_class) def finalize(self, do_compile: bool = True, do_optimize: bool = False): """ @@ -605,7 +614,7 @@ def requires(part): if not part.startswith("<") or ":" not in part: return False - _, pattern_type = part[1:-1].split(":", 1) + _, pattern_type, *__ = part[1:-1].split(":") return ( part.endswith(":path>") diff --git a/sanic_routing/tree.py b/sanic_routing/tree.py index 23c0e52..4825152 100644 --- a/sanic_routing/tree.py +++ b/sanic_routing/tree.py @@ -3,7 +3,7 @@ from .group import RouteGroup from .line import Line -from .patterns import REGEX_PARAM_NAME +from .patterns import REGEX_PARAM_NAME, REGEX_PARAM_NAME_EXT logger = getLogger("sanic.root") @@ -440,7 +440,9 @@ def generate(self, groups: t.Iterable[RouteGroup]) -> None: param = None dynamic = part.startswith("<") if dynamic: - if not REGEX_PARAM_NAME.match(part): + if not REGEX_PARAM_NAME.match( + part + ) and not REGEX_PARAM_NAME_EXT.match(part): raise ValueError(f"Invalid declaration: {part}") part = f"__dynamic__:{group.params[level].label}" param = group.params[level] diff --git a/sanic_routing/utils.py b/sanic_routing/utils.py index c19ece4..7624b9b 100644 --- a/sanic_routing/utils.py +++ b/sanic_routing/utils.py @@ -1,7 +1,9 @@ import re from urllib.parse import quote, unquote -from .patterns import REGEX_PARAM_NAME +from sanic_routing.exceptions import InvalidUsage + +from .patterns import REGEX_PARAM_NAME, REGEX_PARAM_NAME_EXT class Immutable(dict): @@ -74,7 +76,21 @@ def parts_to_path(parts, delimiter="/"): param_type = f":{match.group(2)}" path.append(f"<{match.group(1)}{param_type}>") except AttributeError: - raise ValueError(f"Invalid declaration: {part}") + try: + match = REGEX_PARAM_NAME_EXT.match(part) + filename_type = "" + extension_type = "" + if match.group(2): + filename_type = f"={match.group(2)}" + if match.group(3): + extension_type = f"={match.group(3)}" + segment = ( + f"<{match.group(1)}{filename_type}:" + f"ext{extension_type}>" + ) + path.append(segment) + except AttributeError: + raise InvalidUsage(f"Invalid declaration: {part}") else: path.append(part) return delimiter.join(path) diff --git a/tests/test_builtin_param_types.py b/tests/test_builtin_param_types.py index b08e503..43c5218 100644 --- a/tests/test_builtin_param_types.py +++ b/tests/test_builtin_param_types.py @@ -2,7 +2,7 @@ import pytest from sanic_routing import BaseRouter -from sanic_routing.exceptions import NotFound +from sanic_routing.exceptions import InvalidUsage, NotFound @pytest.fixture @@ -134,6 +134,89 @@ def test_correct_slug_v_string(handler): assert retval == "FooBar" +@pytest.mark.parametrize( + "value", ("somefile.txt", "SomeFile.mp3", "some.thing", "with.extra.dot") +) +def test_ext_not_defined_matches(value): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add("/", handler) + router.finalize() + + _, handler, params = router.get(f"/{value}", "BASE") + retval = handler(**params) + + filename, ext = value.rsplit(".", 1) + assert retval["filename"] == filename + assert retval["ext"] == ext + + +@pytest.mark.parametrize("value", ("somefile.mp3", "with.extra.mp3")) +def test_ext_single_defined_matches(value): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add("/", handler) + router.finalize() + + _, handler, params = router.get(f"/{value}", "BASE") + retval = handler(**params) + + filename, ext = value.rsplit(".", 1) + assert retval["filename"] == filename + assert retval["ext"] == ext + + +@pytest.mark.parametrize( + "value", + ("somefile.png", "with.extra.png", "somefile.jpg", "with.extra.jpg"), +) +def test_ext_multiple_defined_matches(value): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add("/", handler) + router.finalize() + + _, handler, params = router.get(f"/{value}", "BASE") + retval = handler(**params) + + filename, ext = value.rsplit(".", 1) + assert retval["filename"] == filename + assert retval["ext"] == ext + + +@pytest.mark.parametrize( + "path", + ( + "/", + "/", + "/", + ), +) +def test_ext_multiple_defined_filename_types(path): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add(path, handler) + router.finalize() + + _, handler, params = router.get("/123.txt", "BASE") + retval = handler(**params) + + assert retval["filename"] == 123 + assert retval["ext"] == "txt" + + @pytest.mark.parametrize( "value,matches", ( @@ -165,6 +248,75 @@ def test(path): test(path) +@pytest.mark.parametrize( + "value", + ("somefile", "SomeFile."), +) +def test_ext_not_defined_no_matches(handler, value): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add("/", handler) + router.finalize() + + with pytest.raises(NotFound): + router.get(f"/{value}", "BASE") + + +@pytest.mark.parametrize( + "value", + ("somefile", "SomeFile.", "somefile.jpg"), +) +def test_ext_single_defined_no_matches(handler, value): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add("/", handler) + router.finalize() + + with pytest.raises(NotFound): + router.get(f"/{value}", "BASE") + + +@pytest.mark.parametrize( + "value", + ("somefile", "SomeFile.", "somefile.txt"), +) +def test_ext_multiple_defined_no_matches(handler, value): + def handler(**kwargs): + return kwargs + + router = Router() + + router.add("/", handler) + router.finalize() + + with pytest.raises(NotFound): + router.get(f"/{value}", "BASE") + + +@pytest.mark.parametrize( + "definition", + ( + "", + "", + "", + "", + ), +) +def test_bad_ext_definition(handler, definition): + router = Router() + + with pytest.raises(InvalidUsage): + router.add(f"/{definition}", handler) + + @pytest.mark.parametrize( "value", ( @@ -232,3 +384,11 @@ def test_empty_hierarchy(): assert params == expected handler1.assert_not_called() handler2.assert_called_once_with(**expected) + + +def test_invalid_def(handler): + router = Router() + router.add("/one//", handler) + + with pytest.raises(InvalidUsage): + router.finalize() diff --git a/tests/test_routing.py b/tests/test_routing.py index 152b48c..dc04b2b 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -2,6 +2,7 @@ from datetime import date import pytest + from sanic_routing import BaseRouter from sanic_routing.exceptions import NoMethod, NotFound, RouteExists @@ -452,7 +453,13 @@ def handler2(): assert params == {"foo": f"{uri}"} -@pytest.mark.parametrize("uri", ("a-random-path", "a/random/path")) +@pytest.mark.parametrize( + "uri", + ( + "a-random-path", + "a/random/path", + ), +) def test_identical_path_routes_with_different_methods_complex(uri): def handler1(): return "handler1"