diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3ed65591..fb88a294 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,9 @@ Changed ~~~~~~~ - Skip keys with incompatible alg when loading JWKSet by @DaGuich in https://github.com/jpadilla/pyjwt/pull/762 - Remove support for python3.6 +- PyJWT now emits a warning for unsupported keyword arguments being passed to + ``decode`` and ``decode_complete``. Additional keyword arguments are still + supported, but will be rejected in a future version. Fixed ~~~~~ diff --git a/jwt/api_jws.py b/jwt/api_jws.py index f8c8048d..75f826da 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,5 +1,6 @@ import binascii import json +import warnings from collections.abc import Mapping from typing import Any, Dict, List, Optional, Type @@ -16,6 +17,7 @@ InvalidTokenError, ) from .utils import base64url_decode, base64url_encode +from .warnings import RemovedInPyjwt3Warning class PyJWS: @@ -167,6 +169,13 @@ def decode_complete( detached_payload: Optional[bytes] = None, **kwargs, ) -> Dict[str, Any]: + if kwargs: + warnings.warn( + "passing additional kwargs to decode_complete() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) if options is None: options = {} merged_options = {**self.options, **options} @@ -202,9 +211,19 @@ def decode( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, + detached_payload: Optional[bytes] = None, **kwargs, ) -> str: - decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + if kwargs: + warnings.warn( + "passing additional kwargs to decode() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) + decoded = self.decode_complete( + jwt, key, algorithms, options, detached_payload=detached_payload + ) return decoded["payload"] def get_unverified_header(self, jwt): diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index a011c0fe..b08e9502 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -15,6 +15,7 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from .warnings import RemovedInPyjwt3Warning class PyJWT: @@ -69,15 +70,32 @@ def decode_complete( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, + # deprecated arg, remove in pyjwt3 + verify: Optional[bool] = None, + # could be used as passthrough to api_jws, consider removal in pyjwt3 + detached_payload: Optional[bytes] = None, + # passthrough arguments to _validate_claims + # consider putting in options + audience: Optional[str] = None, + issuer: Optional[str] = None, + leeway: Union[int, float, timedelta] = 0, + # kwargs **kwargs, ) -> Dict[str, Any]: + if kwargs: + warnings.warn( + "passing additional kwargs to decode_complete() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) options = dict(options or {}) # shallow-copy or initialize an empty dict options.setdefault("verify_signature", True) # If the user has set the legacy `verify` argument, and it doesn't match # what the relevant `options` entry for the argument is, inform the user # that they're likely making a mistake. - if "verify" in kwargs and kwargs["verify"] != options["verify_signature"]: + if verify is not None and verify != options["verify_signature"]: warnings.warn( "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. " "The equivalent is setting `verify_signature` to False in the `options` dictionary. " @@ -102,7 +120,7 @@ def decode_complete( key=key, algorithms=algorithms, options=options, - **kwargs, + detached_payload=detached_payload, ) try: @@ -113,7 +131,9 @@ def decode_complete( raise DecodeError("Invalid payload string: must be a json object") merged_options = {**self.options, **options} - self._validate_claims(payload, merged_options, **kwargs) + self._validate_claims( + payload, merged_options, audience=audience, issuer=issuer, leeway=leeway + ) decoded["payload"] = payload return decoded @@ -124,14 +144,39 @@ def decode( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict[str, Any]] = None, + # deprecated arg, remove in pyjwt3 + verify: Optional[bool] = None, + # could be used as passthrough to api_jws, consider removal in pyjwt3 + detached_payload: Optional[bytes] = None, + # passthrough arguments to _validate_claims + # consider putting in options + audience: Optional[str] = None, + issuer: Optional[str] = None, + leeway: Union[int, float, timedelta] = 0, + # kwargs **kwargs, ) -> Dict[str, Any]: - decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + if kwargs: + warnings.warn( + "passing additional kwargs to decode() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) + decoded = self.decode_complete( + jwt, + key, + algorithms, + options, + verify=verify, + detached_payload=detached_payload, + audience=audience, + issuer=issuer, + leeway=leeway, + ) return decoded["payload"] - def _validate_claims( - self, payload, options, audience=None, issuer=None, leeway=0, **kwargs - ): + def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0): if isinstance(leeway, timedelta): leeway = leeway.total_seconds() diff --git a/jwt/warnings.py b/jwt/warnings.py new file mode 100644 index 00000000..8762a8cb --- /dev/null +++ b/jwt/warnings.py @@ -0,0 +1,2 @@ +class RemovedInPyjwt3Warning(DeprecationWarning): + pass diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 23975fa4..cfbbe212 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -12,6 +12,7 @@ InvalidTokenError, ) from jwt.utils import base64url_decode +from jwt.warnings import RemovedInPyjwt3Warning from .utils import crypto_required, key_path, no_crypto_required @@ -770,3 +771,37 @@ def test_decode_detached_content_without_proper_argument(self, jws): 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.' in str(exc.value) ) + + def test_decode_warns_on_unsupported_kwarg(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jws.decode( + jws_message, + secret, + algorithms=["HS256"], + detached_payload=payload, + foo="bar", + ) + assert len(record) == 1 + assert "foo" in str(record[0].message) + + def test_decode_complete_warns_on_unuspported_kwarg(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jws.decode_complete( + jws_message, + secret, + algorithms=["HS256"], + detached_payload=payload, + foo="bar", + ) + assert len(record) == 1 + assert "foo" in str(record[0].message) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 84e41e0e..d0443e89 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -17,6 +17,7 @@ MissingRequiredClaimError, ) from jwt.utils import base64url_decode +from jwt.warnings import RemovedInPyjwt3Warning from .utils import crypto_required, key_path, utc_timestamp @@ -682,3 +683,21 @@ def test_decode_no_options_mutation(self, jwt, payload): jwt_message = jwt.encode(payload, secret) jwt.decode(jwt_message, secret, options=options, algorithms=["HS256"]) assert options == orig_options + + def test_decode_warns_on_unsupported_kwarg(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jwt.decode(jwt_message, secret, algorithms=["HS256"], foo="bar") + assert len(record) == 1 + assert "foo" in str(record[0].message) + + def test_decode_complete_warns_on_unsupported_kwarg(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jwt.decode_complete(jwt_message, secret, algorithms=["HS256"], foo="bar") + assert len(record) == 1 + assert "foo" in str(record[0].message)