Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose get_algorithm_by_name as new method #773

Merged
merged 2 commits into from Jul 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -19,6 +19,8 @@ Fixed
Added
~~~~~
- Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732
- Add ``get_algorithm_by_name`` as a method of ``PyJWS`` objects, and expose
the global PyJWS method as part of the public API

`v2.4.0 <https://github.com/jpadilla/pyjwt/compare/2.3.0...2.4.0>`__
-----------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions jwt/__init__.py
@@ -1,6 +1,7 @@
from .api_jwk import PyJWK, PyJWKSet
from .api_jws import (
PyJWS,
get_algorithm_by_name,
get_unverified_header,
register_algorithm,
unregister_algorithm,
Expand Down Expand Up @@ -51,6 +52,7 @@
"get_unverified_header",
"register_algorithm",
"unregister_algorithm",
"get_algorithm_by_name",
# Exceptions
"DecodeError",
"ExpiredSignatureError",
Expand Down
53 changes: 31 additions & 22 deletions jwt/api_jws.py
Expand Up @@ -73,6 +73,23 @@ def get_algorithms(self):
"""
return list(self._valid_algs)

def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
"""
For a given string name, return the matching Algorithm object.

Example usage:

>>> jws_obj.get_algorithm_by_name("RS256")
"""
try:
return self._algorithms[alg_name]
except KeyError as e:
if not has_crypto and alg_name in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e

def encode(
self,
payload: bytes,
Expand All @@ -84,21 +101,21 @@ def encode(
) -> str:
segments = []

if algorithm is None:
algorithm = "none"
# declare a new var to narrow the type for type checkers
algorithm_: str = algorithm if algorithm is not None else "none"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no strong opinions about this 👍


# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
algorithm = headers["alg"]
algorithm_ = headers["alg"]

headers_b64 = headers.get("b64")
if headers_b64 is False:
is_payload_detached = True

# Header
header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any]
header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]

if headers:
self._validate_headers(headers)
Expand Down Expand Up @@ -128,17 +145,9 @@ def encode(
# Segments
signing_input = b".".join(segments)

try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

except KeyError as e:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e
alg_obj = self.get_algorithm_by_name(algorithm_)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

segments.append(base64url_encode(signature))

Expand Down Expand Up @@ -262,14 +271,13 @@ def _verify_signature(
raise InvalidAlgorithmError("The specified alg value is not allowed")

try:
alg_obj = self._algorithms[alg]
key = alg_obj.prepare_key(key)

if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed")

except KeyError as e:
alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
key = alg_obj.prepare_key(key)

if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed")

def _validate_headers(self, headers):
if "kid" in headers:
Expand All @@ -286,4 +294,5 @@ def _validate_kid(self, kid):
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header