Skip to content

Commit

Permalink
Merge pull request #958 from sirosen/bool-format-checks
Browse files Browse the repository at this point in the history
Type annotate format checker methods
  • Loading branch information
Julian committed Jun 2, 2022
2 parents fd3a457 + ded065d commit b1c1d00
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Expand Up @@ -25,6 +25,7 @@
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx_autodoc_typehints",
"sphinxcontrib.spelling",
"jsonschema_role",
]
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.in
Expand Up @@ -2,4 +2,5 @@ file:.#egg=jsonschema
furo
lxml
sphinx
sphinx-autodoc-typehints
sphinxcontrib-spelling
7 changes: 5 additions & 2 deletions docs/requirements.txt
Expand Up @@ -28,7 +28,7 @@ jinja2==3.1.2
# via sphinx
file:.#egg=jsonschema
# via -r docs/requirements.in
lxml==4.8.0
lxml==4.9.0
# via -r docs/requirements.in
markupsafe==2.1.1
# via jinja2
Expand Down Expand Up @@ -56,7 +56,10 @@ sphinx==4.5.0
# via
# -r docs/requirements.in
# furo
# sphinx-autodoc-typehints
# sphinxcontrib-spelling
sphinx-autodoc-typehints==1.18.1
# via -r docs/requirements.in
sphinxcontrib-applehelp==1.0.2
# via sphinx
sphinxcontrib-devhelp==1.0.2
Expand All @@ -69,7 +72,7 @@ sphinxcontrib-qthelp==1.0.3
# via sphinx
sphinxcontrib-serializinghtml==1.1.5
# via sphinx
sphinxcontrib-spelling==7.4.1
sphinxcontrib-spelling==7.5.0
# via -r docs/requirements.in
urllib3==1.26.9
# via requests
103 changes: 60 additions & 43 deletions jsonschema/_format.py
Expand Up @@ -9,6 +9,12 @@

from jsonschema.exceptions import FormatError

_FormatCheckCallable = typing.Callable[[object], bool]
_F = typing.TypeVar("_F", bound=_FormatCheckCallable)
_RaisesType = typing.Union[
typing.Type[Exception], typing.Tuple[typing.Type[Exception], ...],
]


class FormatChecker(object):
"""
Expand All @@ -35,13 +41,10 @@ class FormatChecker(object):

checkers: dict[
str,
tuple[
typing.Callable[[typing.Any], bool],
Exception | tuple[Exception, ...],
],
tuple[_FormatCheckCallable, _RaisesType],
] = {}

def __init__(self, formats=None):
def __init__(self, formats: typing.Iterable[str] | None = None):
if formats is None:
self.checkers = self.checkers.copy()
else:
Expand All @@ -50,7 +53,9 @@ def __init__(self, formats=None):
def __repr__(self):
return "<FormatChecker checkers={}>".format(sorted(self.checkers))

def checks(self, format, raises=()):
def checks(
self, format: str, raises: _RaisesType = (),
) -> typing.Callable[[_F], _F]:
"""
Register a decorated function as validating a new format.
Expand All @@ -70,14 +75,23 @@ def checks(self, format, raises=()):
resulting validation error.
"""

def _checks(func):
def _checks(func: _F) -> _F:
self.checkers[format] = (func, raises)
return func

return _checks

cls_checks = classmethod(checks)
@classmethod
def cls_checks(
cls, format: str, raises: _RaisesType = (),
) -> typing.Callable[[_F], _F]:
def _checks(func: _F) -> _F:
cls.checkers[format] = (func, raises)
return func

return _checks

def check(self, instance, format):
def check(self, instance: object, format: str) -> None:
"""
Check whether the instance conforms to the given format.
Expand Down Expand Up @@ -109,7 +123,7 @@ def check(self, instance, format):
if not result:
raise FormatError(f"{instance!r} is not a {format!r}", cause=cause)

def conforms(self, instance, format):
def conforms(self, instance: object, format: str) -> bool:
"""
Check whether the instance conforms to the given format.
Expand Down Expand Up @@ -143,7 +157,7 @@ def conforms(self, instance, format):
draft201909_format_checker = FormatChecker()
draft202012_format_checker = FormatChecker()

_draft_checkers = dict(
_draft_checkers: dict[str, FormatChecker] = dict(
draft3=draft3_format_checker,
draft4=draft4_format_checker,
draft6=draft6_format_checker,
Expand All @@ -162,15 +176,15 @@ def _checks_drafts(
draft201909=None,
draft202012=None,
raises=(),
):
) -> typing.Callable[[_F], _F]:
draft3 = draft3 or name
draft4 = draft4 or name
draft6 = draft6 or name
draft7 = draft7 or name
draft201909 = draft201909 or name
draft202012 = draft202012 or name

def wrap(func):
def wrap(func: _F) -> _F:
if draft3:
func = _draft_checkers["draft3"].checks(draft3, raises)(func)
if draft4:
Expand All @@ -195,12 +209,13 @@ def wrap(func):
raises,
)(func)
return func

return wrap


@_checks_drafts(name="idn-email")
@_checks_drafts(name="email")
def is_email(instance):
def is_email(instance: object) -> bool:
if not isinstance(instance, str):
return True
return "@" in instance
Expand All @@ -215,14 +230,14 @@ def is_email(instance):
draft202012="ipv4",
raises=ipaddress.AddressValueError,
)
def is_ipv4(instance):
def is_ipv4(instance: object) -> bool:
if not isinstance(instance, str):
return True
return ipaddress.IPv4Address(instance)
return bool(ipaddress.IPv4Address(instance))


@_checks_drafts(name="ipv6", raises=ipaddress.AddressValueError)
def is_ipv6(instance):
def is_ipv6(instance: object) -> bool:
if not isinstance(instance, str):
return True
address = ipaddress.IPv6Address(instance)
Expand All @@ -240,7 +255,7 @@ def is_ipv6(instance):
draft201909="hostname",
draft202012="hostname",
)
def is_host_name(instance):
def is_host_name(instance: object) -> bool:
if not isinstance(instance, str):
return True
return FQDN(instance).is_valid
Expand All @@ -256,7 +271,7 @@ def is_host_name(instance):
draft202012="idn-hostname",
raises=(idna.IDNAError, UnicodeError),
)
def is_idn_host_name(instance):
def is_idn_host_name(instance: object) -> bool:
if not isinstance(instance, str):
return True
idna.encode(instance)
Expand All @@ -270,7 +285,7 @@ def is_idn_host_name(instance):
from rfc3986_validator import validate_rfc3986

@_checks_drafts(name="uri")
def is_uri(instance):
def is_uri(instance: object) -> bool:
if not isinstance(instance, str):
return True
return validate_rfc3986(instance, rule="URI")
Expand All @@ -282,19 +297,20 @@ def is_uri(instance):
draft202012="uri-reference",
raises=ValueError,
)
def is_uri_reference(instance):
def is_uri_reference(instance: object) -> bool:
if not isinstance(instance, str):
return True
return validate_rfc3986(instance, rule="URI_reference")

else:

@_checks_drafts(
draft7="iri",
draft201909="iri",
draft202012="iri",
raises=ValueError,
)
def is_iri(instance):
def is_iri(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="IRI")
Expand All @@ -305,13 +321,13 @@ def is_iri(instance):
draft202012="iri-reference",
raises=ValueError,
)
def is_iri_reference(instance):
def is_iri_reference(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="IRI_reference")

@_checks_drafts(name="uri", raises=ValueError)
def is_uri(instance):
def is_uri(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="URI")
Expand All @@ -323,16 +339,17 @@ def is_uri(instance):
draft202012="uri-reference",
raises=ValueError,
)
def is_uri_reference(instance):
def is_uri_reference(instance: object) -> bool:
if not isinstance(instance, str):
return True
return rfc3987.parse(instance, rule="URI_reference")


with suppress(ImportError):
from rfc3339_validator import validate_rfc3339

@_checks_drafts(name="date-time")
def is_datetime(instance):
def is_datetime(instance: object) -> bool:
if not isinstance(instance, str):
return True
return validate_rfc3339(instance.upper())
Expand All @@ -342,17 +359,17 @@ def is_datetime(instance):
draft201909="time",
draft202012="time",
)
def is_time(instance):
def is_time(instance: object) -> bool:
if not isinstance(instance, str):
return True
return is_datetime("1970-01-01T" + instance)


@_checks_drafts(name="regex", raises=re.error)
def is_regex(instance):
def is_regex(instance: object) -> bool:
if not isinstance(instance, str):
return True
return re.compile(instance)
return bool(re.compile(instance))


@_checks_drafts(
Expand All @@ -362,28 +379,28 @@ def is_regex(instance):
draft202012="date",
raises=ValueError,
)
def is_date(instance):
def is_date(instance: object) -> bool:
if not isinstance(instance, str):
return True
return instance.isascii() and datetime.date.fromisoformat(instance)
return bool(instance.isascii() and datetime.date.fromisoformat(instance))


@_checks_drafts(draft3="time", raises=ValueError)
def is_draft3_time(instance):
def is_draft3_time(instance: object) -> bool:
if not isinstance(instance, str):
return True
return datetime.datetime.strptime(instance, "%H:%M:%S")
return bool(datetime.datetime.strptime(instance, "%H:%M:%S"))


with suppress(ImportError):
from webcolors import CSS21_NAMES_TO_HEX
import webcolors

def is_css_color_code(instance):
def is_css_color_code(instance: object) -> bool:
return webcolors.normalize_hex(instance)

@_checks_drafts(draft3="color", raises=(ValueError, TypeError))
def is_css21_color(instance):
def is_css21_color(instance: object) -> bool:
if (
not isinstance(instance, str)
or instance.lower() in CSS21_NAMES_TO_HEX
Expand All @@ -402,10 +419,10 @@ def is_css21_color(instance):
draft202012="json-pointer",
raises=jsonpointer.JsonPointerException,
)
def is_json_pointer(instance):
def is_json_pointer(instance: object) -> bool:
if not isinstance(instance, str):
return True
return jsonpointer.JsonPointer(instance)
return bool(jsonpointer.JsonPointer(instance))

# TODO: I don't want to maintain this, so it
# needs to go either into jsonpointer (pending
Expand All @@ -417,7 +434,7 @@ def is_json_pointer(instance):
draft202012="relative-json-pointer",
raises=jsonpointer.JsonPointerException,
)
def is_relative_json_pointer(instance):
def is_relative_json_pointer(instance: object) -> bool:
# Definition taken from:
# https://tools.ietf.org/html/draft-handrews-relative-json-pointer-01#section-3
if not isinstance(instance, str):
Expand All @@ -437,7 +454,7 @@ def is_relative_json_pointer(instance):

rest = instance[i:]
break
return (rest == "#") or jsonpointer.JsonPointer(rest)
return (rest == "#") or bool(jsonpointer.JsonPointer(rest))


with suppress(ImportError):
Expand All @@ -449,7 +466,7 @@ def is_relative_json_pointer(instance):
draft201909="uri-template",
draft202012="uri-template",
)
def is_uri_template(instance):
def is_uri_template(instance: object) -> bool:
if not isinstance(instance, str):
return True
return uri_template.validate(instance)
Expand All @@ -463,18 +480,18 @@ def is_uri_template(instance):
draft202012="duration",
raises=isoduration.DurationParsingException,
)
def is_duration(instance):
def is_duration(instance: object) -> bool:
if not isinstance(instance, str):
return True
return isoduration.parse_duration(instance)
return bool(isoduration.parse_duration(instance))


@_checks_drafts(
draft201909="uuid",
draft202012="uuid",
raises=ValueError,
)
def is_uuid(instance):
def is_uuid(instance: object) -> bool:
if not isinstance(instance, str):
return True
UUID(instance)
Expand Down

0 comments on commit b1c1d00

Please sign in to comment.