Skip to content

Commit

Permalink
Prefer headers['alg'] to algorithm parameter in encode(). (#673)
Browse files Browse the repository at this point in the history
* Prefer headers['alg'] to algorithm parameter in encode().

* Fix lack of @crypto_required.

* Prefer headers['alg'] to algorithm parameter in encode().

* Prefer headers['alg'] to algorithm parameter in encode().

* Make algorithm parameter of encode() Optioanl explicitly.
  • Loading branch information
dajiaji committed Aug 6, 2021
1 parent cfe5261 commit 8de4428
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -13,6 +13,7 @@ Changed
Fixed
~~~~~

- Prefer `headers["alg"]` to `algorithm` in `jwt.encode()`. `#673 <https://github.com/jpadilla/pyjwt/pull/673>`__
- Fix aud validation to support {'aud': null} case. `#670 <https://github.com/jpadilla/pyjwt/pull/670>`__

Added
Expand Down
5 changes: 3 additions & 2 deletions docs/api.rst
Expand Up @@ -13,8 +13,9 @@ API Reference
* for **asymmetric algorithms**: PEM-formatted private key, a multiline string
* for **symmetric algorithms**: plain string, sufficiently long for security

:param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``
:param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``
:param str algorithm: algorithm to sign the token with, e.g. ``"ES256"``.
If ``headers`` includes ``alg``, it will be preferred to this parameter.
:param dict headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``.
:param json.JSONEncoder json_encoder: custom JSON encoder for ``payload`` and ``headers``
:rtype: str
:returns: a JSON Web Token
Expand Down
6 changes: 5 additions & 1 deletion jwt/api_jws.py
Expand Up @@ -77,7 +77,7 @@ def encode(
self,
payload: bytes,
key: str,
algorithm: str = "HS256",
algorithm: Optional[str] = "HS256",
headers: Optional[Dict] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
) -> str:
Expand All @@ -86,6 +86,10 @@ def encode(
if algorithm is None:
algorithm = "none"

# Prefer headers["alg"] if present to algorithm parameter.
if headers and "alg" in headers and headers["alg"]:
algorithm = headers["alg"]

if algorithm not in self._valid_algs:
pass

Expand Down
2 changes: 1 addition & 1 deletion jwt/api_jwt.py
Expand Up @@ -38,7 +38,7 @@ def encode(
self,
payload: Dict[str, Any],
key: str,
algorithm: str = "HS256",
algorithm: Optional[str] = "HS256",
headers: Optional[Dict] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
) -> str:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_api_jws.py
Expand Up @@ -166,6 +166,32 @@ def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):
exception = context.value
assert str(exception) == "Algorithm not supported"

def test_encode_with_headers_alg_none(self, jws, payload):
msg = jws.encode(payload, key=None, headers={"alg": "none"})
with pytest.raises(DecodeError) as context:
jws.decode(msg, algorithms=["none"])
assert str(context.value) == "Signature verification failed"

@crypto_required
def test_encode_with_headers_alg_es256(self, jws, payload):
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_key = load_pem_private_key(ec_priv_file.read(), password=None)
with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
pub_key = load_pem_public_key(ec_pub_file.read())

msg = jws.encode(payload, priv_key, headers={"alg": "ES256"})
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])

@crypto_required
def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_key = load_pem_private_key(ec_priv_file.read(), password=None)
with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
pub_key = load_pem_public_key(ec_pub_file.read())

msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])

def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
example_jws = (
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256
Expand Down

0 comments on commit 8de4428

Please sign in to comment.