From 00b33bb2a8fbc3a52dc75e5d3bdfd2dcf66869c6 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Sat, 28 May 2022 21:45:35 +0000 Subject: [PATCH] Type annotate format checker methods The goal of this work is to apply annotations in `jsonschema._format` which match the new additions to the type stubs recently added to typeshed. In short, it defines a `_FormatCheckCallable` as a callable -- typically function -- which takes any object and returns a bool. All of the `is_*` functions are `_FormatCheckCallable`s, and the decorators are carefully annotated as being type-preserving using a typevar. Because the returned objects from these functions are not always bools, this changeset now calls `bool()` on those which are not (e.g. `ipaddress.IPv4Address`). This is really just an explicit form of the check which is going to happen in `conforms` and `check` anyway, so there's no significant new cost. The advantage of this is that we have documented (via the annotations) what a format check function is supposed to do: it returns a bool. We could equally well return `Any` from these functions, relying on `__bool__`, but this could confuse new contributors and users. One unfortunate side effect of these changes is that `FormatChecker.cls_checks` needs to be expanded into a full duplicate of `FormatChecker.checks`. `mypy` isn't able to properly understand what `cls_checks = classmethod(checks)` does -- this is part of a class of semi-sophisticated callable and method manipulations that are known to be problematic for `mypy`. In order to get the annotations correct, the simplest solutions are either to annotate it explicitly (cast or type comment) or to expand it as this changeset has done. --- jsonschema/_format.py | 108 ++++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/jsonschema/_format.py b/jsonschema/_format.py index bd8377f35..9f240b3db 100644 --- a/jsonschema/_format.py +++ b/jsonschema/_format.py @@ -9,6 +9,12 @@ from jsonschema.exceptions import FormatError +_FormatCheckCallable = typing.Callable[[object], bool] +_F = typing.TypeVar("_F", bound=_FormatCheckCallable) +_RaisesType = typing.Union[ + typing.Type[Exception], typing.Tuple[typing.Type[Exception], ...] +] + class FormatChecker(object): """ @@ -35,13 +41,10 @@ class FormatChecker(object): checkers: dict[ str, - tuple[ - typing.Callable[[typing.Any], bool], - Exception | tuple[Exception, ...], - ], + tuple[_FormatCheckCallable, _RaisesType], ] = {} - def __init__(self, formats=None): + def __init__(self, formats: typing.Iterable[str] | None = None): if formats is None: self.checkers = self.checkers.copy() else: @@ -50,7 +53,9 @@ def __init__(self, formats=None): def __repr__(self): return "".format(sorted(self.checkers)) - def checks(self, format, raises=()): + def checks( + self, format: str, raises: _RaisesType = () + ) -> typing.Callable[[_F], _F]: """ Register a decorated function as validating a new format. @@ -70,14 +75,23 @@ def checks(self, format, raises=()): resulting validation error. """ - def _checks(func): + def _checks(func: _F) -> _F: self.checkers[format] = (func, raises) return func + return _checks - cls_checks = classmethod(checks) + @classmethod + def cls_checks( + cls, format: str, raises: _RaisesType = () + ) -> typing.Callable[[_F], _F]: + def _checks(func: _F) -> _F: + cls.checkers[format] = (func, raises) + return func + + return _checks - def check(self, instance, format): + def check(self, instance: object, format: str) -> None: """ Check whether the instance conforms to the given format. @@ -109,7 +123,7 @@ def check(self, instance, format): if not result: raise FormatError(f"{instance!r} is not a {format!r}", cause=cause) - def conforms(self, instance, format): + def conforms(self, instance: object, format: str) -> bool: """ Check whether the instance conforms to the given format. @@ -143,7 +157,7 @@ def conforms(self, instance, format): draft201909_format_checker = FormatChecker() draft202012_format_checker = FormatChecker() -_draft_checkers = dict( +_draft_checkers: dict[str, FormatChecker] = dict( draft3=draft3_format_checker, draft4=draft4_format_checker, draft6=draft6_format_checker, @@ -162,7 +176,7 @@ def _checks_drafts( draft201909=None, draft202012=None, raises=(), -): +) -> typing.Callable[[_F], _F]: draft3 = draft3 or name draft4 = draft4 or name draft6 = draft6 or name @@ -170,7 +184,7 @@ def _checks_drafts( draft201909 = draft201909 or name draft202012 = draft202012 or name - def wrap(func): + def wrap(func: _F) -> _F: if draft3: func = _draft_checkers["draft3"].checks(draft3, raises)(func) if draft4: @@ -195,12 +209,13 @@ def wrap(func): raises, )(func) return func + return wrap @_checks_drafts(name="idn-email") @_checks_drafts(name="email") -def is_email(instance): +def is_email(instance: object) -> bool: if not isinstance(instance, str): return True return "@" in instance @@ -215,14 +230,14 @@ def is_email(instance): draft202012="ipv4", raises=ipaddress.AddressValueError, ) -def is_ipv4(instance): +def is_ipv4(instance: object) -> bool: if not isinstance(instance, str): return True - return ipaddress.IPv4Address(instance) + return bool(ipaddress.IPv4Address(instance)) @_checks_drafts(name="ipv6", raises=ipaddress.AddressValueError) -def is_ipv6(instance): +def is_ipv6(instance: object) -> bool: if not isinstance(instance, str): return True address = ipaddress.IPv6Address(instance) @@ -240,7 +255,7 @@ def is_ipv6(instance): draft201909="hostname", draft202012="hostname", ) - def is_host_name(instance): + def is_host_name(instance: object) -> bool: if not isinstance(instance, str): return True return FQDN(instance).is_valid @@ -256,7 +271,7 @@ def is_host_name(instance): draft202012="idn-hostname", raises=(idna.IDNAError, UnicodeError), ) - def is_idn_host_name(instance): + def is_idn_host_name(instance: object) -> bool: if not isinstance(instance, str): return True idna.encode(instance) @@ -270,7 +285,7 @@ def is_idn_host_name(instance): from rfc3986_validator import validate_rfc3986 @_checks_drafts(name="uri") - def is_uri(instance): + def is_uri(instance: object) -> bool: if not isinstance(instance, str): return True return validate_rfc3986(instance, rule="URI") @@ -282,19 +297,20 @@ def is_uri(instance): draft202012="uri-reference", raises=ValueError, ) - def is_uri_reference(instance): + def is_uri_reference(instance: object) -> bool: if not isinstance(instance, str): return True return validate_rfc3986(instance, rule="URI_reference") else: + @_checks_drafts( draft7="iri", draft201909="iri", draft202012="iri", raises=ValueError, ) - def is_iri(instance): + def is_iri(instance: object) -> bool: if not isinstance(instance, str): return True return rfc3987.parse(instance, rule="IRI") @@ -305,13 +321,13 @@ def is_iri(instance): draft202012="iri-reference", raises=ValueError, ) - def is_iri_reference(instance): + def is_iri_reference(instance: object) -> bool: if not isinstance(instance, str): return True return rfc3987.parse(instance, rule="IRI_reference") @_checks_drafts(name="uri", raises=ValueError) - def is_uri(instance): + def is_uri(instance: object) -> bool: if not isinstance(instance, str): return True return rfc3987.parse(instance, rule="URI") @@ -323,16 +339,17 @@ def is_uri(instance): draft202012="uri-reference", raises=ValueError, ) - def is_uri_reference(instance): + def is_uri_reference(instance: object) -> bool: if not isinstance(instance, str): return True return rfc3987.parse(instance, rule="URI_reference") + with suppress(ImportError): from rfc3339_validator import validate_rfc3339 @_checks_drafts(name="date-time") - def is_datetime(instance): + def is_datetime(instance: object) -> bool: if not isinstance(instance, str): return True return validate_rfc3339(instance.upper()) @@ -342,17 +359,17 @@ def is_datetime(instance): draft201909="time", draft202012="time", ) - def is_time(instance): + def is_time(instance: object) -> bool: if not isinstance(instance, str): return True return is_datetime("1970-01-01T" + instance) @_checks_drafts(name="regex", raises=re.error) -def is_regex(instance): +def is_regex(instance: object) -> bool: if not isinstance(instance, str): return True - return re.compile(instance) + return bool(re.compile(instance)) @_checks_drafts( @@ -362,32 +379,29 @@ def is_regex(instance): draft202012="date", raises=ValueError, ) -def is_date(instance): +def is_date(instance: object) -> bool: if not isinstance(instance, str): return True - return instance.isascii() and datetime.date.fromisoformat(instance) + return bool(instance.isascii() and datetime.date.fromisoformat(instance)) @_checks_drafts(draft3="time", raises=ValueError) -def is_draft3_time(instance): +def is_draft3_time(instance: object) -> bool: if not isinstance(instance, str): return True - return datetime.datetime.strptime(instance, "%H:%M:%S") + return bool(datetime.datetime.strptime(instance, "%H:%M:%S")) with suppress(ImportError): from webcolors import CSS21_NAMES_TO_HEX import webcolors - def is_css_color_code(instance): + def is_css_color_code(instance: object) -> bool: return webcolors.normalize_hex(instance) @_checks_drafts(draft3="color", raises=(ValueError, TypeError)) - def is_css21_color(instance): - if ( - not isinstance(instance, str) - or instance.lower() in CSS21_NAMES_TO_HEX - ): + def is_css21_color(instance: object) -> bool: + if not isinstance(instance, str) or instance.lower() in CSS21_NAMES_TO_HEX: return True return is_css_color_code(instance) @@ -402,10 +416,10 @@ def is_css21_color(instance): draft202012="json-pointer", raises=jsonpointer.JsonPointerException, ) - def is_json_pointer(instance): + def is_json_pointer(instance: object) -> bool: if not isinstance(instance, str): return True - return jsonpointer.JsonPointer(instance) + return bool(jsonpointer.JsonPointer(instance)) # TODO: I don't want to maintain this, so it # needs to go either into jsonpointer (pending @@ -417,7 +431,7 @@ def is_json_pointer(instance): draft202012="relative-json-pointer", raises=jsonpointer.JsonPointerException, ) - def is_relative_json_pointer(instance): + def is_relative_json_pointer(instance: object) -> bool: # Definition taken from: # https://tools.ietf.org/html/draft-handrews-relative-json-pointer-01#section-3 if not isinstance(instance, str): @@ -437,7 +451,7 @@ def is_relative_json_pointer(instance): rest = instance[i:] break - return (rest == "#") or jsonpointer.JsonPointer(rest) + return (rest == "#") or bool(jsonpointer.JsonPointer(rest)) with suppress(ImportError): @@ -449,7 +463,7 @@ def is_relative_json_pointer(instance): draft201909="uri-template", draft202012="uri-template", ) - def is_uri_template(instance): + def is_uri_template(instance: object) -> bool: if not isinstance(instance, str): return True return uri_template.validate(instance) @@ -463,10 +477,10 @@ def is_uri_template(instance): draft202012="duration", raises=isoduration.DurationParsingException, ) - def is_duration(instance): + def is_duration(instance: object) -> bool: if not isinstance(instance, str): return True - return isoduration.parse_duration(instance) + return bool(isoduration.parse_duration(instance)) @_checks_drafts( @@ -474,7 +488,7 @@ def is_duration(instance): draft202012="uuid", raises=ValueError, ) -def is_uuid(instance): +def is_uuid(instance: object) -> bool: if not isinstance(instance, str): return True UUID(instance)