diff --git a/jwt/api_jws.py b/jwt/api_jws.py index a61d2277..c0c7409d 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -134,6 +134,7 @@ def decode_complete( key: str = "", algorithms: List[str] = None, options: Dict = None, + **kwargs, ) -> Dict[str, Any]: if options is None: options = {} @@ -162,8 +163,9 @@ def decode( key: str = "", algorithms: List[str] = None, options: Dict = None, + **kwargs, ) -> str: - decoded = self.decode_complete(jwt, key, algorithms, options) + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) return decoded["payload"] def get_unverified_header(self, jwt): diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 7e21b754..f3b55d36 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -68,9 +68,7 @@ def decode_complete( key: str = "", algorithms: List[str] = None, options: Dict = None, - audience: Optional[Union[str, List[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[float, timedelta] = 0, + **kwargs, ) -> Dict[str, Any]: if options is None: options = {"verify_signature": True} @@ -94,6 +92,7 @@ def decode_complete( key=key, algorithms=algorithms, options=options, + **kwargs, ) try: @@ -104,7 +103,7 @@ def decode_complete( raise DecodeError("Invalid payload string: must be a json object") merged_options = {**self.options, **options} - self._validate_claims(payload, merged_options, audience, issuer, leeway) + self._validate_claims(payload, merged_options, **kwargs) decoded["payload"] = payload return decoded @@ -115,20 +114,18 @@ def decode( key: str = "", algorithms: List[str] = None, options: Dict = None, - audience: Optional[Union[str, List[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[float, timedelta] = 0, + **kwargs, ) -> Dict[str, Any]: - decoded = self.decode_complete( - jwt, key, algorithms, options, audience, issuer, leeway - ) + decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) return decoded["payload"] - def _validate_claims(self, payload, options, audience, issuer, leeway): + def _validate_claims( + self, payload, options, audience=None, issuer=None, leeway=0, **kwargs + ): if isinstance(leeway, timedelta): leeway = leeway.total_seconds() - if not isinstance(audience, (str, type(None), Iterable)): + if not isinstance(audience, (bytes, str, type(None), Iterable)): raise TypeError("audience must be a string, iterable, or None") self._validate_required_claims(payload, options) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 53fd150e..fa3167a4 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -106,17 +106,6 @@ def test_decode_with_non_mapping_payload_throws_exception(self, jwt): exception = context.value assert str(exception) == "Invalid payload string: must be a json object" - def test_decode_with_unknown_parameter_throws_exception(self, jwt): - secret = "secret" - example_jwt = ( - b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" - b".eyJoZWxsbyI6ICJ3b3JsZCJ9" - b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8" - ) - - with pytest.raises(TypeError): - jwt.decode(example_jwt, key=secret, foo="bar", algorithms=["HS256"]) - def test_decode_with_invalid_audience_param_throws_exception(self, jwt): secret = "secret" example_jwt = (