diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 763822a1..b049e9e8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ Changed Fixed ~~~~~ +- Prefer `headers["alg"]` to `algorithm` in `jwt.encode()`. `#673 `__ - Fix aud validation to support {'aud': null} case. `#670 `__ Added diff --git a/docs/api.rst b/docs/api.rst index d23d4bdd..f1720b93 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -13,8 +13,9 @@ API Reference * for **asymmetric algorithms**: PEM-formatted private key, a multiline string * for **symmetric algorithms**: plain string, sufficiently long for security - :param str algorithm: algorithm to sign the token with, e.g. ``"ES256"`` - :param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")`` + :param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``. + If ``headers`` includes ``alg``, it will be preferred to this parameter. + :param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``. :param json.JSONEncoder json_encoder: custom JSON encoder for ``payload`` and ``headers`` :rtype: str :returns: a JSON Web Token diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 6d136999..6cefb2c2 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -77,7 +77,7 @@ def encode( self, payload: bytes, key: str, - algorithm: str = "HS256", + algorithm: Optional[str] = "HS256", headers: Optional[Dict] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, ) -> str: @@ -86,6 +86,10 @@ def encode( if algorithm is None: algorithm = "none" + # Prefer headers["alg"] if present to algorithm parameter. + if headers and "alg" in headers and headers["alg"]: + algorithm = headers["alg"] + if algorithm not in self._valid_algs: pass diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 55a8f291..48a93162 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -38,7 +38,7 @@ def encode( self, payload: Dict[str, Any], key: str, - algorithm: str = "HS256", + algorithm: Optional[str] = "HS256", headers: Optional[Dict] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, ) -> str: diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index b731edce..fbeda024 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -166,6 +166,32 @@ def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload): exception = context.value assert str(exception) == "Algorithm not supported" + def test_encode_with_headers_alg_none(self, jws, payload): + msg = jws.encode(payload, key=None, headers={"alg": "none"}) + with pytest.raises(DecodeError) as context: + jws.decode(msg, algorithms=["none"]) + assert str(context.value) == "Signature verification failed" + + @crypto_required + def test_encode_with_headers_alg_es256(self, jws, payload): + with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: + priv_key = load_pem_private_key(ec_priv_file.read(), password=None) + with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file: + pub_key = load_pem_public_key(ec_pub_file.read()) + + msg = jws.encode(payload, priv_key, headers={"alg": "ES256"}) + assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"]) + + @crypto_required + def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload): + with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: + priv_key = load_pem_private_key(ec_priv_file.read(), password=None) + with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file: + pub_key = load_pem_public_key(ec_pub_file.read()) + + msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"}) + assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"]) + def test_decode_algorithm_param_should_be_case_sensitive(self, jws): example_jws = ( "eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256