Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add detached payload support for JWS encoding and decoding #723

Merged
merged 1 commit into from Mar 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 40 additions & 9 deletions jwt/api_jws.py
Expand Up @@ -80,34 +80,54 @@ def encode(
algorithm: Optional[str] = "HS256",
headers: Optional[Dict] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
is_payload_detached: bool = False,
) -> str:
segments = []

if algorithm is None:
algorithm = "none"

# Prefer headers["alg"] if present to algorithm parameter.
if headers and "alg" in headers and headers["alg"]:
algorithm = headers["alg"]
# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
algorithm = headers["alg"]

headers_b64 = headers.get("b64")
if headers_b64 is False:
is_payload_detached = True

# Header
header = {"typ": self.header_typ, "alg": algorithm}
header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any]

if headers:
self._validate_headers(headers)
header.update(headers)
if not header["typ"]:
del header["typ"]

if not header["typ"]:
del header["typ"]

if is_payload_detached:
header["b64"] = False
elif "b64" in header:
# True is the standard value for b64, so no need for it
del header["b64"]

json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder
).encode()

segments.append(base64url_encode(json_header))
segments.append(base64url_encode(payload))

if is_payload_detached:
msg_payload = payload
else:
msg_payload = base64url_encode(payload)
segments.append(msg_payload)

# Segments
signing_input = b".".join(segments)

try:
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key)
Expand All @@ -119,11 +139,13 @@ def encode(
"Algorithm '%s' could not be found. Do you have cryptography "
"installed?" % algorithm
) from e
else:
raise NotImplementedError("Algorithm not supported") from e
raise NotImplementedError("Algorithm not supported") from e

segments.append(base64url_encode(signature))

# Don't put the payload content inside the encoded token when detached
if is_payload_detached:
segments[1] = b""
encoded_string = b".".join(segments)

return encoded_string.decode("utf-8")
Expand All @@ -134,6 +156,7 @@ def decode_complete(
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict] = None,
detached_payload: Optional[bytes] = None,
**kwargs,
) -> Dict[str, Any]:
if options is None:
Expand All @@ -148,6 +171,14 @@ def decode_complete(

payload, signing_input, header, signature = self._load(jwt)

if header.get("b64", True) is False:
if detached_payload is None:
raise DecodeError(
'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
)
payload = detached_payload
signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

if verify_signature:
self._verify_signature(signing_input, header, signature, key, algorithms)

Expand Down
51 changes: 51 additions & 0 deletions tests/test_api_jws.py
Expand Up @@ -719,3 +719,54 @@ def test_encode_fails_on_invalid_kid_types(self, jws, payload):
jws.encode(payload, "secret", headers={"kid": None})

assert "Key ID header parameter must be a string" == str(exc.value)

def test_encode_decode_with_detached_content(self, jws, payload):
secret = "secret"
jws_message = jws.encode(
payload, secret, algorithm="HS256", is_payload_detached=True
)

jws.decode(jws_message, secret, algorithms=["HS256"], detached_payload=payload)

def test_encode_detached_content_with_b64_header(self, jws, payload):
secret = "secret"

# Check that detached content is automatically detected when b64 is false
headers = {"b64": False}
token = jws.encode(payload, secret, "HS256", headers)

msg_header, msg_payload, _ = token.split(".")
msg_header = base64url_decode(msg_header.encode())
msg_header_obj = json.loads(msg_header)

assert "b64" in msg_header_obj
assert msg_header_obj["b64"] is False
# Check that the payload is not inside the token
assert not msg_payload

# Check that content is not detached and b64 header removed when b64 is true
headers = {"b64": True}
token = jws.encode(payload, secret, "HS256", headers)

msg_header, msg_payload, _ = token.split(".")
msg_header = base64url_decode(msg_header.encode())
msg_header_obj = json.loads(msg_header)

assert "b64" not in msg_header_obj
assert msg_payload

def test_decode_detached_content_without_proper_argument(self, jws):
example_jws = (
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImI2NCI6ZmFsc2V9"
"."
".65yNkX_ZH4A_6pHaTL_eI84OXOHtfl4K0k5UnlXZ8f4"
)
example_secret = "secret"

with pytest.raises(DecodeError) as exc:
jws.decode(example_jws, example_secret, algorithms=["HS256"])

assert (
'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
in str(exc.value)
)