Skip to content

Commit

Permalink
Make Algorithm an abstract base class (#845)
Browse files Browse the repository at this point in the history
* Make `Algorithm` an abstract base class

This also removes some tests that are not relevant anymore

Raise `NotImplementedError` for `NoneAlgorithm`

* Use `hasattr` instead of `getattr`

* Only allow `dict` in `encode`
  • Loading branch information
Viicos committed Mar 7, 2023
1 parent 5a2a6b6 commit 777efa2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 52 deletions.
29 changes: 19 additions & 10 deletions jwt/algorithms.py
@@ -1,6 +1,7 @@
import hashlib
import hmac
import json
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Type, Union

from .exceptions import InvalidKeyError
Expand Down Expand Up @@ -117,7 +118,7 @@ def get_default_algorithms() -> Dict[str, "Algorithm"]:
return default_algorithms


class Algorithm:
class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
Expand Down Expand Up @@ -148,40 +149,40 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
# variadic (TypeVar) but as discussed in https://github.com/jpadilla/pyjwt/pull/605
# that may still be poorly supported.

@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
raise NotImplementedError

@abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
raise NotImplementedError

@abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def to_jwk(key_obj) -> JWKDict:
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def from_jwk(jwk: JWKDict):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
raise NotImplementedError


class NoneAlgorithm(Algorithm):
Expand All @@ -205,6 +206,14 @@ def sign(self, msg, key):
def verify(self, msg, key, sig):
return False

@staticmethod
def to_jwk(key_obj) -> JWKDict:
raise NotImplementedError()

@staticmethod
def from_jwk(jwk: JWKDict):
raise NotImplementedError()


class HMACAlgorithm(Algorithm):
"""
Expand Down Expand Up @@ -299,7 +308,7 @@ def prepare_key(self, key):
def to_jwk(key_obj):
obj = None

if getattr(key_obj, "private_numbers", None):
if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()

Expand All @@ -316,7 +325,7 @@ def to_jwk(key_obj):
"qi": to_base64url_uint(numbers.iqmp).decode(),
}

elif getattr(key_obj, "verify", None):
elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()

Expand Down Expand Up @@ -587,7 +596,7 @@ def sign(self, msg, key):
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
Expand All @@ -599,7 +608,7 @@ def verify(self, msg, key, sig):
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
Expand Down
8 changes: 4 additions & 4 deletions jwt/api_jwt.py
Expand Up @@ -3,7 +3,7 @@
import json
import warnings
from calendar import timegm
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union

Expand Down Expand Up @@ -47,10 +47,10 @@ def encode(
json_encoder: Optional[Type[json.JSONEncoder]] = None,
sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
# Check that we get a dict
if not isinstance(payload, dict):
raise TypeError(
"Expecting a mapping object, as JWT only supports "
"Expecting a dict object, as JWT only supports "
"JSON objects as payloads."
)

Expand Down
46 changes: 11 additions & 35 deletions tests/test_algorithms.py
Expand Up @@ -3,7 +3,7 @@

import pytest

from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm, has_crypto
from jwt.algorithms import HMACAlgorithm, NoneAlgorithm, has_crypto
from jwt.exceptions import InvalidKeyError
from jwt.utils import base64url_decode

Expand All @@ -15,47 +15,23 @@


class TestAlgorithms:
def test_algorithm_should_throw_exception_if_prepare_key_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.prepare_key("test")

def test_algorithm_should_throw_exception_if_sign_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.sign(b"message", "key")

def test_algorithm_should_throw_exception_if_verify_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.verify(b"message", "key", b"signature")

def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self):
algo = Algorithm()

with pytest.raises(NotImplementedError):
algo.from_jwk({"val": "ue"})

def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self):
algo = Algorithm()
def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()

with pytest.raises(NotImplementedError):
algo.to_jwk("value")
with pytest.raises(InvalidKeyError):
algo.prepare_key("123")

def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self):
algo = Algorithm()
def test_none_algorithm_should_throw_exception_on_to_jwk(self):
algo = NoneAlgorithm()

with pytest.raises(NotImplementedError):
algo.compute_hash_digest(b"value")
algo.to_jwk("dummy") # Using a dummy argument as is it not relevant

def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
def test_none_algorithm_should_throw_exception_on_from_jwk(self):
algo = NoneAlgorithm()

with pytest.raises(InvalidKeyError):
algo.prepare_key("123")
with pytest.raises(NotImplementedError):
algo.from_jwk({}) # Using a dummy argument as is it not relevant

def test_hmac_should_reject_nonstring_key(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_api_jws.py
Expand Up @@ -3,7 +3,7 @@

import pytest

from jwt.algorithms import Algorithm, has_crypto
from jwt.algorithms import NoneAlgorithm, has_crypto
from jwt.api_jws import PyJWS
from jwt.exceptions import (
DecodeError,
Expand Down Expand Up @@ -39,10 +39,10 @@ def payload():

class TestJWS:
def test_register_algo_does_not_allow_duplicate_registration(self, jws):
jws.register_algorithm("AAA", Algorithm())
jws.register_algorithm("AAA", NoneAlgorithm())

with pytest.raises(ValueError):
jws.register_algorithm("AAA", Algorithm())
jws.register_algorithm("AAA", NoneAlgorithm())

def test_register_algo_rejects_non_algorithm_obj(self, jws):
with pytest.raises(TypeError):
Expand Down

0 comments on commit 777efa2

Please sign in to comment.