From 581f4e17b836ffc4600e413e93d39296291dacd4 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Tue, 5 Jul 2022 01:38:46 -0400 Subject: [PATCH] Emit a deprecation warning for unsupported kwargs (#776) `**kwargs` usages cannot be removed without breaking backwards compatibility. Unsupported kwargs cannot even be rejected without breaking compatibility. However, this does not mean that the library cannot identify and warn when unsupported arguments are used. The warning behavior simply has to be separated from any removal of `**kwargs`. All legitimate `**kwargs` usages have been replaced with explicit arguments. Any other arguments will be captured under `**kwargs` and trigger the deprecation warnings. In the cases of `decode() -> decode_complete()` passthrough, the passthrough has been removed to avoid duplicate deprecation warnings on a single usage. This makes a very subtle behavioral change to `**kwargs` *only* for the case of a subclass of PyJWT or PyJWS. Extra arguments used by a specialized subclass won't pass through transparently anymore. In such a case the subclass author has multiple resolutions available, including reimplementation of the `decode()` method to passthrough the additional argument. Although technically backwards-incompatible for a niche subclassing usage, this behavior is very nearly identical and shouldn't pose an issue for the vast majority of pyjwt users. The deprecation warning does not cover all deprecated usages. In particular, several passthrough arguments for claim validation should probably be made available via `options` and later removed. The arguments in need of attention now have inline comments in the signature definitions, but are otherwise left unmodified, leaving current usages correct and valid. --- CHANGELOG.rst | 3 +++ jwt/api_jws.py | 21 ++++++++++++++- jwt/api_jwt.py | 59 ++++++++++++++++++++++++++++++++++++++----- jwt/warnings.py | 2 ++ tests/test_api_jws.py | 35 +++++++++++++++++++++++++ tests/test_api_jwt.py | 19 ++++++++++++++ 6 files changed, 131 insertions(+), 8 deletions(-) create mode 100644 jwt/warnings.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3ed65591..fb88a294 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,9 @@ Changed ~~~~~~~ - Skip keys with incompatible alg when loading JWKSet by @DaGuich in https://github.com/jpadilla/pyjwt/pull/762 - Remove support for python3.6 +- PyJWT now emits a warning for unsupported keyword arguments being passed to + ``decode`` and ``decode_complete``. Additional keyword arguments are still + supported, but will be rejected in a future version. Fixed ~~~~~ diff --git a/jwt/api_jws.py b/jwt/api_jws.py index f8c8048d..75f826da 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,5 +1,6 @@ import binascii import json +import warnings from collections.abc import Mapping from typing import Any, Dict, List, Optional, Type @@ -16,6 +17,7 @@ InvalidTokenError, ) from .utils import base64url_decode, base64url_encode +from .warnings import RemovedInPyjwt3Warning class PyJWS: @@ -167,6 +169,13 @@ def decode_complete( detached_payload: Optional[bytes] = None, **kwargs, ) -> Dict[str, Any]: + if kwargs: + warnings.warn( + "passing additional kwargs to decode_complete() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) if options is None: options = {} merged_options = {**self.options, **options} @@ -202,9 +211,19 @@ def decode( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, + detached_payload: Optional[bytes] = None, **kwargs, ) -> str: - decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + if kwargs: + warnings.warn( + "passing additional kwargs to decode() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) + decoded = self.decode_complete( + jwt, key, algorithms, options, detached_payload=detached_payload + ) return decoded["payload"] def get_unverified_header(self, jwt): diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index a011c0fe..b08e9502 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -15,6 +15,7 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from .warnings import RemovedInPyjwt3Warning class PyJWT: @@ -69,15 +70,32 @@ def decode_complete( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, + # deprecated arg, remove in pyjwt3 + verify: Optional[bool] = None, + # could be used as passthrough to api_jws, consider removal in pyjwt3 + detached_payload: Optional[bytes] = None, + # passthrough arguments to _validate_claims + # consider putting in options + audience: Optional[str] = None, + issuer: Optional[str] = None, + leeway: Union[int, float, timedelta] = 0, + # kwargs **kwargs, ) -> Dict[str, Any]: + if kwargs: + warnings.warn( + "passing additional kwargs to decode_complete() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) options = dict(options or {}) # shallow-copy or initialize an empty dict options.setdefault("verify_signature", True) # If the user has set the legacy `verify` argument, and it doesn't match # what the relevant `options` entry for the argument is, inform the user # that they're likely making a mistake. - if "verify" in kwargs and kwargs["verify"] != options["verify_signature"]: + if verify is not None and verify != options["verify_signature"]: warnings.warn( "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. " "The equivalent is setting `verify_signature` to False in the `options` dictionary. " @@ -102,7 +120,7 @@ def decode_complete( key=key, algorithms=algorithms, options=options, - **kwargs, + detached_payload=detached_payload, ) try: @@ -113,7 +131,9 @@ def decode_complete( raise DecodeError("Invalid payload string: must be a json object") merged_options = {**self.options, **options} - self._validate_claims(payload, merged_options, **kwargs) + self._validate_claims( + payload, merged_options, audience=audience, issuer=issuer, leeway=leeway + ) decoded["payload"] = payload return decoded @@ -124,14 +144,39 @@ def decode( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, + # deprecated arg, remove in pyjwt3 + verify: Optional[bool] = None, + # could be used as passthrough to api_jws, consider removal in pyjwt3 + detached_payload: Optional[bytes] = None, + # passthrough arguments to _validate_claims + # consider putting in options + audience: Optional[str] = None, + issuer: Optional[str] = None, + leeway: Union[int, float, timedelta] = 0, + # kwargs **kwargs, ) -> Dict[str, Any]: - decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + if kwargs: + warnings.warn( + "passing additional kwargs to decode() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) + decoded = self.decode_complete( + jwt, + key, + algorithms, + options, + verify=verify, + detached_payload=detached_payload, + audience=audience, + issuer=issuer, + leeway=leeway, + ) return decoded["payload"] - def _validate_claims( - self, payload, options, audience=None, issuer=None, leeway=0, **kwargs - ): + def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0): if isinstance(leeway, timedelta): leeway = leeway.total_seconds() diff --git a/jwt/warnings.py b/jwt/warnings.py new file mode 100644 index 00000000..8762a8cb --- /dev/null +++ b/jwt/warnings.py @@ -0,0 +1,2 @@ +class RemovedInPyjwt3Warning(DeprecationWarning): + pass diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 23975fa4..cfbbe212 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -12,6 +12,7 @@ InvalidTokenError, ) from jwt.utils import base64url_decode +from jwt.warnings import RemovedInPyjwt3Warning from .utils import crypto_required, key_path, no_crypto_required @@ -770,3 +771,37 @@ def test_decode_detached_content_without_proper_argument(self, jws): 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.' in str(exc.value) ) + + def test_decode_warns_on_unsupported_kwarg(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jws.decode( + jws_message, + secret, + algorithms=["HS256"], + detached_payload=payload, + foo="bar", + ) + assert len(record) == 1 + assert "foo" in str(record[0].message) + + def test_decode_complete_warns_on_unuspported_kwarg(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jws.decode_complete( + jws_message, + secret, + algorithms=["HS256"], + detached_payload=payload, + foo="bar", + ) + assert len(record) == 1 + assert "foo" in str(record[0].message) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 84e41e0e..d0443e89 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -17,6 +17,7 @@ MissingRequiredClaimError, ) from jwt.utils import base64url_decode +from jwt.warnings import RemovedInPyjwt3Warning from .utils import crypto_required, key_path, utc_timestamp @@ -682,3 +683,21 @@ def test_decode_no_options_mutation(self, jwt, payload): jwt_message = jwt.encode(payload, secret) jwt.decode(jwt_message, secret, options=options, algorithms=["HS256"]) assert options == orig_options + + def test_decode_warns_on_unsupported_kwarg(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jwt.decode(jwt_message, secret, algorithms=["HS256"], foo="bar") + assert len(record) == 1 + assert "foo" in str(record[0].message) + + def test_decode_complete_warns_on_unsupported_kwarg(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jwt.decode_complete(jwt_message, secret, algorithms=["HS256"], foo="bar") + assert len(record) == 1 + assert "foo" in str(record[0].message)