Skip to content

Commit

Permalink
Type annotate format checker methods
Browse files Browse the repository at this point in the history
The goal of this work is to apply annotations in `jsonschema._format`
which match the new additions to the type stubs recently added to
typeshed.

In short, it defines a `_FormatCheckCallable` as a callable --
typically function -- which takes any object and returns a bool.
All of the `is_*` functions are `_FormatCheckCallable`s, and the
decorators are carefully annotated as being type-preserving using a
typevar.

Because the returned objects from these functions are not always
bools, this changeset now calls `bool()` on those which are not (e.g.
`ipaddress.IPv4Address`). This is really just an explicit form of the
check which is going to happen in `conforms` and `check` anyway, so
there's no significant new cost. The advantage of this is that we have
documented (via the annotations) what a format check function is
supposed to do: it returns a bool. We could equally well return `Any`
from these functions, relying on `__bool__`, but this could confuse
new contributors and users.

One unfortunate side effect of these changes is that
`FormatChecker.cls_checks` needs to be expanded into a full duplicate
of `FormatChecker.checks`. `mypy` isn't able to properly understand
what `cls_checks = classmethod(checks)` does -- this is part of a
class of semi-sophisticated callable and method manipulations that are
known to be problematic for `mypy`. In order to get the annotations
correct, the simplest solutions are either to annotate it explicitly
(cast or type comment) or to expand it as this changeset has done.
  • Loading branch information
sirosen committed May 28, 2022
1 parent acc3ed2 commit 00b33bb
Showing 1 changed file with 61 additions and 47 deletions.
108 changes: 61 additions & 47 deletions jsonschema/_format.py
Expand Up @@ -9,6 +9,12 @@

from jsonschema.exceptions import FormatError

_FormatCheckCallable = typing.Callable[[object], bool]
_F = typing.TypeVar("_F", bound=_FormatCheckCallable)
_RaisesType = typing.Union[
typing.Type[Exception], typing.Tuple[typing.Type[Exception], ...]
]


class FormatChecker(object):
"""
Expand All @@ -35,13 +41,10 @@ class FormatChecker(object):

checkers: dict[
str,
tuple[
typing.Callable[[typing.Any], bool],
Exception | tuple[Exception, ...],
],
tuple[_FormatCheckCallable, _RaisesType],
] = {}

def __init__(self, formats=None):
def __init__(self, formats: typing.Iterable[str] | None = None):
if formats is None:
self.checkers = self.checkers.copy()
else:
Expand All @@ -50,7 +53,9 @@ def __init__(self, formats=None):
def __repr__(self):
return "<FormatChecker checkers={}>".format(sorted(self.checkers))

def checks(self, format, raises=()):
def checks(
self, format: str, raises: _RaisesType = ()
) -> typing.Callable[[_F], _F]:
"""
Register a decorated function as validating a new format.
Expand All @@ -70,14 +75,23 @@ def checks(self, format, raises=()):
resulting validation error.
"""

def _checks(func):
def _checks(func: _F) -> _F:
self.checkers[format] = (func, raises)
return func

return _checks

cls_checks = classmethod(checks)
@classmethod
def cls_checks(
cls, format: str, raises: _RaisesType = ()
) -> typing.Callable[[_F], _F]:
def _checks(func: _F) -> _F:
cls.checkers[format] = (func, raises)
return func

return _checks

def check(self, instance, format):
def check(self, instance: object, format: str) -> None:
"""
Check whether the instance conforms to the given format.
Expand Down Expand Up @@ -109,7 +123,7 @@ def check(self, instance, format):
if not result:
raise FormatError(f"{instance!r} is not a {format!r}", cause=cause)

def conforms(self, instance, format):
def conforms(self, instance: object, format: str) -> bool:
"""
Check whether the instance conforms to the given format.
Expand Down Expand Up @@ -143,7 +157,7 @@ def conforms(self, instance, format):
draft201909_format_checker = FormatChecker()
draft202012_format_checker = FormatChecker()

_draft_checkers = dict(
_draft_checkers: dict[str, FormatChecker] = dict(
draft3=draft3_format_checker,
draft4=draft4_format_checker,
draft6=draft6_format_checker,
Expand All @@ -162,15 +176,15 @@ def _checks_drafts(
draft201909=None,
draft202012=None,
raises=(),
):
) -> typing.Callable[[_F], _F]:
draft3 = draft3 or name
draft4 = draft4 or name
draft6 = draft6 or name
draft7 = draft7 or name
draft201909 = draft201909 or name
draft202012 = draft202012 or name

def wrap(func):
def wrap(func: _F) -> _F:
if draft3:
func = _draft_checkers["draft3"].checks(draft3, raises)(func)
if draft4:
Expand All @@ -195,12 +209,13 @@ def wrap(func):
raises,
)(func)
return func

return wrap


@_checks_drafts(name="idn-email")
@_checks_drafts(name="email")
def is_email(instance):
def is_email(instance: object) -> bool:
if not isinstance(instance, str):
return True
return "@" in instance
Expand All @@ -215,14 +230,14 @@ def is_email(instance):
draft202012="ipv4",
raises=ipaddress.AddressValueError,
)
def is_ipv4(instance):
def is_ipv4(instance: object) -> bool:
if not isinstance(instance, str):
return True
return ipaddress.IPv4Address(instance)
return bool(ipaddress.IPv4Address(instance))


@_checks_drafts(name="ipv6", raises=ipaddress.AddressValueError)
def is_ipv6(instance):
def is_ipv6(instance: object) -> bool:
if not isinstance(instance, str):
return True
address = ipaddress.IPv6Address(instance)
Expand All @@ -240,7 +255,7 @@ def is_ipv6(instance):
draft201909="hostname",
draft202012="hostname",
)
def is_host_name(instance):
def is_host_name(instance: object) -> bool:
if not isinstance(instance, str):
return True
return FQDN(instance).is_valid
Expand All @@ -256,7 +271,7 @@ def is_host_name(instance):
draft202012="idn-hostname",
raises=(idna.IDNAError, UnicodeError),
)
def is_idn_host_name(instance):
def is_idn_host_name(instance: object) -> bool:
if not isinstance(instance, str):
return True
idna.encode(instance)
Expand All @@ -270,7 +285,7 @@ def is_idn_host_name(instance):
from rfc3986_validator import validate_rfc3986

@_checks_drafts(name="uri")
def is_uri(instance):
def is_uri(instance: object) -> bool:
if not isinstance(instance, str):
return True
return validate_rfc3986(instance, rule="URI")
Expand All @@ -282,19 +297,20 @@ def is_uri(instance):
draft202012="uri-reference",
raises=ValueError,
)
def is_uri_reference(instance):
def is_uri_reference(instance: object) -> bool:
if not isinstance(instance, str):
return True
return validate_rfc3986(instance, rule="URI_reference")

else:

@_checks_drafts(
draft7="iri",
draft201909="iri",
draft202012="iri",
raises=ValueError,
)
def is_iri(instance):
def is_iri(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="IRI")
Expand All @@ -305,13 +321,13 @@ def is_iri(instance):
draft202012="iri-reference",
raises=ValueError,
)
def is_iri_reference(instance):
def is_iri_reference(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="IRI_reference")

@_checks_drafts(name="uri", raises=ValueError)
def is_uri(instance):
def is_uri(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="URI")
Expand All @@ -323,16 +339,17 @@ def is_uri(instance):
draft202012="uri-reference",
raises=ValueError,
)
def is_uri_reference(instance):
def is_uri_reference(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="URI_reference")


with suppress(ImportError):
from rfc3339_validator import validate_rfc3339

@_checks_drafts(name="date-time")
def is_datetime(instance):
def is_datetime(instance: object) -> bool:
if not isinstance(instance, str):
return True
return validate_rfc3339(instance.upper())
Expand All @@ -342,17 +359,17 @@ def is_datetime(instance):
draft201909="time",
draft202012="time",
)
def is_time(instance):
def is_time(instance: object) -> bool:
if not isinstance(instance, str):
return True
return is_datetime("1970-01-01T" + instance)


@_checks_drafts(name="regex", raises=re.error)
def is_regex(instance):
def is_regex(instance: object) -> bool:
if not isinstance(instance, str):
return True
return re.compile(instance)
return bool(re.compile(instance))


@_checks_drafts(
Expand All @@ -362,32 +379,29 @@ def is_regex(instance):
draft202012="date",
raises=ValueError,
)
def is_date(instance):
def is_date(instance: object) -> bool:
if not isinstance(instance, str):
return True
return instance.isascii() and datetime.date.fromisoformat(instance)
return bool(instance.isascii() and datetime.date.fromisoformat(instance))


@_checks_drafts(draft3="time", raises=ValueError)
def is_draft3_time(instance):
def is_draft3_time(instance: object) -> bool:
if not isinstance(instance, str):
return True
return datetime.datetime.strptime(instance, "%H:%M:%S")
return bool(datetime.datetime.strptime(instance, "%H:%M:%S"))


with suppress(ImportError):
from webcolors import CSS21_NAMES_TO_HEX
import webcolors

def is_css_color_code(instance):
def is_css_color_code(instance: object) -> bool:
return webcolors.normalize_hex(instance)

@_checks_drafts(draft3="color", raises=(ValueError, TypeError))
def is_css21_color(instance):
if (
not isinstance(instance, str)
or instance.lower() in CSS21_NAMES_TO_HEX
):
def is_css21_color(instance: object) -> bool:
if not isinstance(instance, str) or instance.lower() in CSS21_NAMES_TO_HEX:
return True
return is_css_color_code(instance)

Expand All @@ -402,10 +416,10 @@ def is_css21_color(instance):
draft202012="json-pointer",
raises=jsonpointer.JsonPointerException,
)
def is_json_pointer(instance):
def is_json_pointer(instance: object) -> bool:
if not isinstance(instance, str):
return True
return jsonpointer.JsonPointer(instance)
return bool(jsonpointer.JsonPointer(instance))

# TODO: I don't want to maintain this, so it
# needs to go either into jsonpointer (pending
Expand All @@ -417,7 +431,7 @@ def is_json_pointer(instance):
draft202012="relative-json-pointer",
raises=jsonpointer.JsonPointerException,
)
def is_relative_json_pointer(instance):
def is_relative_json_pointer(instance: object) -> bool:
# Definition taken from:
# https://tools.ietf.org/html/draft-handrews-relative-json-pointer-01#section-3
if not isinstance(instance, str):
Expand All @@ -437,7 +451,7 @@ def is_relative_json_pointer(instance):

rest = instance[i:]
break
return (rest == "#") or jsonpointer.JsonPointer(rest)
return (rest == "#") or bool(jsonpointer.JsonPointer(rest))


with suppress(ImportError):
Expand All @@ -449,7 +463,7 @@ def is_relative_json_pointer(instance):
draft201909="uri-template",
draft202012="uri-template",
)
def is_uri_template(instance):
def is_uri_template(instance: object) -> bool:
if not isinstance(instance, str):
return True
return uri_template.validate(instance)
Expand All @@ -463,18 +477,18 @@ def is_uri_template(instance):
draft202012="duration",
raises=isoduration.DurationParsingException,
)
def is_duration(instance):
def is_duration(instance: object) -> bool:
if not isinstance(instance, str):
return True
return isoduration.parse_duration(instance)
return bool(isoduration.parse_duration(instance))


@_checks_drafts(
draft201909="uuid",
draft202012="uuid",
raises=ValueError,
)
def is_uuid(instance):
def is_uuid(instance: object) -> bool:
if not isinstance(instance, str):
return True
UUID(instance)
Expand Down

0 comments on commit 00b33bb

Please sign in to comment.