Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy as pre-commit check + api_jws typing #787

Merged
merged 3 commits into from Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Expand Up @@ -34,3 +34,8 @@ repos:
hooks:
- id: check-manifest
args: [--no-build-isolation]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v0.971"
hooks:
- id: mypy
61 changes: 31 additions & 30 deletions jwt/api_jws.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import binascii
import json
import warnings
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Type
from typing import Any, Type

from .algorithms import (
Algorithm,
Expand All @@ -23,7 +24,7 @@
class PyJWS:
header_typ = "JWT"

def __init__(self, algorithms=None, options=None):
def __init__(self, algorithms=None, options=None) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
Expand All @@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None):
self.options = {**self._get_default_options(), **options}

@staticmethod
def _get_default_options():
def _get_default_options() -> dict[str, bool]:
return {"verify_signature": True}

def register_algorithm(self, alg_id, alg_obj):
def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
Expand All @@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj):
self._algorithms[alg_id] = alg_obj
self._valid_algs.add(alg_id)

def unregister_algorithm(self, alg_id):
def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
Expand All @@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id):
del self._algorithms[alg_id]
self._valid_algs.remove(alg_id)

def get_algorithms(self):
def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
Expand All @@ -96,9 +97,9 @@ def encode(
self,
payload: bytes,
key: str,
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
) -> str:
segments = []
Expand All @@ -117,7 +118,7 @@ def encode(
is_payload_detached = True

# Header
header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

if headers:
self._validate_headers(headers)
Expand Down Expand Up @@ -165,11 +166,11 @@ def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
detached_payload: Optional[bytes] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
Expand Down Expand Up @@ -210,9 +211,9 @@ def decode(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
detached_payload: Optional[bytes] = None,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> str:
if kwargs:
Expand All @@ -227,7 +228,7 @@ def decode(
)
return decoded["payload"]

def get_unverified_header(self, jwt):
def get_unverified_header(self, jwt: str | bytes) -> dict:
"""Returns back the JWT header parameters as a dict()

Note: The signature is not verified so the header parameters
Expand All @@ -238,7 +239,7 @@ def get_unverified_header(self, jwt):

return headers

def _load(self, jwt):
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")

Expand All @@ -261,7 +262,7 @@ def _load(self, jwt):
except ValueError as e:
raise DecodeError(f"Invalid header string: {e}") from e

if not isinstance(header, Mapping):
if not isinstance(header, dict):
raise DecodeError("Invalid header string: must be a json object")

try:
Expand All @@ -278,16 +279,16 @@ def _load(self, jwt):

def _verify_signature(
self,
signing_input,
header,
signature,
key="",
algorithms=None,
):
signing_input: bytes,
header: dict,
signature: bytes,
key: str = "",
algorithms: list[str] | None = None,
) -> None:

alg = header.get("alg")

if algorithms is not None and alg not in algorithms:
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")

try:
Expand All @@ -299,11 +300,11 @@ def _verify_signature(
if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed")

def _validate_headers(self, headers):
def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers:
self._validate_kid(headers["kid"])

def _validate_kid(self, kid):
def _validate_kid(self, kid: str) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")

Expand Down
13 changes: 7 additions & 6 deletions jwt/help.py
Expand Up @@ -8,7 +8,7 @@
try:
import cryptography
except ModuleNotFoundError:
cryptography = None # type: ignore
cryptography = None


def info() -> Dict[str, Dict[str, str]]:
Expand All @@ -29,14 +29,15 @@ def info() -> Dict[str, Dict[str, str]]:
if implementation == "CPython":
implementation_version = platform.python_version()
elif implementation == "PyPy":
pypy_version_info = getattr(sys, "pypy_version_info")
implementation_version = (
f"{sys.pypy_version_info.major}." # type: ignore[attr-defined]
f"{sys.pypy_version_info.minor}."
f"{sys.pypy_version_info.micro}"
f"{pypy_version_info.major}."
f"{pypy_version_info.minor}."
f"{pypy_version_info.micro}"
)
if sys.pypy_version_info.releaselevel != "final": # type: ignore[attr-defined]
if pypy_version_info.releaselevel != "final":
implementation_version = "".join(
[implementation_version, sys.pypy_version_info.releaselevel] # type: ignore[attr-defined]
[implementation_version, pypy_version_info.releaselevel]
)
else:
implementation_version = "Unknown"
Expand Down
4 changes: 2 additions & 2 deletions jwt/utils.py
@@ -1,7 +1,7 @@
import base64
import binascii
import re
from typing import Any, Union
from typing import Union

try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
Expand All @@ -10,7 +10,7 @@
encode_dss_signature,
)
except ModuleNotFoundError:
EllipticCurve = Any # type: ignore
EllipticCurve = None


def force_bytes(value: Union[str, bytes]) -> bytes:
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Expand Up @@ -2,7 +2,6 @@
requires = ["setuptools"]
build-backend = "setuptools.build_meta"


[tool.coverage.run]
parallel = true
branch = true
Expand All @@ -14,8 +13,13 @@ source = ["jwt", ".tox/*/site-packages"]
[tool.coverage.report]
show_missing = true


[tool.isort]
profile = "black"
atomic = true
combine_as_imports = true

[tool.mypy]
python_version = 3.7
ignore_missing_imports = true
warn_unused_ignores = true
no_implicit_optional = true
7 changes: 0 additions & 7 deletions setup.cfg
Expand Up @@ -57,7 +57,6 @@ dev =
types-cryptography>=3.3.21
pytest>=6.0.0,<7.0.0
coverage[toml]==5.0.4
mypy
pre-commit

[options.packages.find]
Expand All @@ -67,9 +66,3 @@ exclude =

[flake8]
extend-ignore = E203, E501

[mypy]
python_version = 3.7
ignore_missing_imports = true
warn_unused_ignores = true
no_implicit_optional = true
6 changes: 0 additions & 6 deletions tox.ini
Expand Up @@ -48,12 +48,6 @@ commands =
python -m doctest README.rst


[testenv:typing]
basepython = python3.8
extras = dev
commands = mypy jwt


[testenv:lint]
basepython = python3.8
extras = dev
Expand Down