Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode with PyJWK #886

Merged
merged 14 commits into from May 11, 2024
1 change: 1 addition & 0 deletions jwt/api_jwk.py
Expand Up @@ -52,6 +52,7 @@ def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
if not has_crypto and algorithm in requires_cryptography:
raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")

self.algorithm = algorithm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo having both self.algorithm and self.Algorithm may be too similar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understandable. Any suggestions? self.algorithm_name? self.alg? (To match the JWK field)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.algorithm_name seems alright, actually I think self.Algorithm is the one being wrong, and self.algorithm_class or something similar might be better. Anyway better not breaking any public attribute

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm +1 on @Viicos's suggestion.

self.Algorithm = self._algorithms.get(algorithm)

if not self.Algorithm:
Expand Down
10 changes: 8 additions & 2 deletions jwt/api_jws.py
Expand Up @@ -11,6 +11,7 @@
has_crypto,
requires_cryptography,
)
from .api_jwk import PyJWK
from .exceptions import (
DecodeError,
InvalidAlgorithmError,
Expand Down Expand Up @@ -172,7 +173,7 @@
def decode_complete(
self,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
Expand All @@ -190,6 +191,11 @@
merged_options = {**self.options, **options}
verify_signature = merged_options["verify_signature"]

if isinstance(key, PyJWK):
if algorithms is None:
algorithms = [key.algorithm]
key = key.key

Check warning on line 197 in jwt/api_jws.py

View check run for this annotation

Codecov / codecov/patch

jwt/api_jws.py#L196-L197

Added lines #L196 - L197 were not covered by tests

if verify_signature and not algorithms:
raise DecodeError(
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
Expand Down Expand Up @@ -217,7 +223,7 @@
def decode(
self,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
Expand Down
7 changes: 4 additions & 3 deletions jwt/api_jwt.py
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any

from . import api_jws
from . import api_jwk, api_jws
from .exceptions import (
DecodeError,
ExpiredSignatureError,
Expand All @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
from .api_jwk import PyJWK


class PyJWT:
Expand Down Expand Up @@ -100,7 +101,7 @@ def _encode_payload(
def decode_complete(
self,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
Expand Down Expand Up @@ -185,7 +186,7 @@ def _decode_payload(self, decoded: dict[str, Any]) -> Any:
def decode(
self,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
Expand Down