diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c11ce06a..81d421e2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,8 @@ Fixed Added ~~~~~ - Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732 +- Add ``get_algorithm_by_name`` as a method of ``PyJWS`` objects, and expose + the global PyJWS method as part of the public API `v2.4.0 `__ ----------------------------------------------------------------------- diff --git a/jwt/__init__.py b/jwt/__init__.py index 6b3f8ab1..a96cc6ee 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -1,6 +1,7 @@ from .api_jwk import PyJWK, PyJWKSet from .api_jws import ( PyJWS, + get_algorithm_by_name, get_unverified_header, register_algorithm, unregister_algorithm, @@ -51,6 +52,7 @@ "get_unverified_header", "register_algorithm", "unregister_algorithm", + "get_algorithm_by_name", # Exceptions "DecodeError", "ExpiredSignatureError", diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 16ec846c..f8c8048d 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -73,6 +73,23 @@ def get_algorithms(self): """ return list(self._valid_algs) + def get_algorithm_by_name(self, alg_name: str) -> Algorithm: + """ + For a given string name, return the matching Algorithm object. + + Example usage: + + >>> jws_obj.get_algorithm_by_name("RS256") + """ + try: + return self._algorithms[alg_name] + except KeyError as e: + if not has_crypto and alg_name in requires_cryptography: + raise NotImplementedError( + f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?" + ) from e + raise NotImplementedError("Algorithm not supported") from e + def encode( self, payload: bytes, @@ -84,21 +101,21 @@ def encode( ) -> str: segments = [] - if algorithm is None: - algorithm = "none" + # declare a new var to narrow the type for type checkers + algorithm_: str = algorithm if algorithm is not None else "none" # Prefer headers values if present to function parameters. if headers: headers_alg = headers.get("alg") if headers_alg: - algorithm = headers["alg"] + algorithm_ = headers["alg"] headers_b64 = headers.get("b64") if headers_b64 is False: is_payload_detached = True # Header - header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any] + header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any] if headers: self._validate_headers(headers) @@ -128,17 +145,9 @@ def encode( # Segments signing_input = b".".join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) - - except KeyError as e: - if not has_crypto and algorithm in requires_cryptography: - raise NotImplementedError( - f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?" - ) from e - raise NotImplementedError("Algorithm not supported") from e + alg_obj = self.get_algorithm_by_name(algorithm_) + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) segments.append(base64url_encode(signature)) @@ -262,14 +271,13 @@ def _verify_signature( raise InvalidAlgorithmError("The specified alg value is not allowed") try: - alg_obj = self._algorithms[alg] - key = alg_obj.prepare_key(key) - - if not alg_obj.verify(signing_input, key, signature): - raise InvalidSignatureError("Signature verification failed") - - except KeyError as e: + alg_obj = self.get_algorithm_by_name(alg) + except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e + key = alg_obj.prepare_key(key) + + if not alg_obj.verify(signing_input, key, signature): + raise InvalidSignatureError("Signature verification failed") def _validate_headers(self, headers): if "kid" in headers: @@ -286,4 +294,5 @@ def _validate_kid(self, kid): decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm +get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name get_unverified_header = _jws_global_obj.get_unverified_header