diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 9e8e1787..f32de8fb 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -80,34 +80,54 @@ def encode( algorithm: Optional[str] = "HS256", headers: Optional[Dict] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, + is_payload_detached: bool = False, ) -> str: segments = [] if algorithm is None: algorithm = "none" - # Prefer headers["alg"] if present to algorithm parameter. - if headers and "alg" in headers and headers["alg"]: - algorithm = headers["alg"] + # Prefer headers values if present to function parameters. + if headers: + headers_alg = headers.get("alg") + if headers_alg: + algorithm = headers["alg"] + + headers_b64 = headers.get("b64") + if headers_b64 is False: + is_payload_detached = True # Header - header = {"typ": self.header_typ, "alg": algorithm} + header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any] if headers: self._validate_headers(headers) header.update(headers) - if not header["typ"]: - del header["typ"] + + if not header["typ"]: + del header["typ"] + + if is_payload_detached: + header["b64"] = False + elif "b64" in header: + # True is the standard value for b64, so no need for it + del header["b64"] json_header = json.dumps( header, separators=(",", ":"), cls=json_encoder ).encode() segments.append(base64url_encode(json_header)) - segments.append(base64url_encode(payload)) + + if is_payload_detached: + msg_payload = payload + else: + msg_payload = base64url_encode(payload) + segments.append(msg_payload) # Segments signing_input = b".".join(segments) + try: alg_obj = self._algorithms[algorithm] key = alg_obj.prepare_key(key) @@ -119,11 +139,13 @@ def encode( "Algorithm '%s' could not be found. Do you have cryptography " "installed?" % algorithm ) from e - else: - raise NotImplementedError("Algorithm not supported") from e + raise NotImplementedError("Algorithm not supported") from e segments.append(base64url_encode(signature)) + # Don't put the payload content inside the encoded token when detached + if is_payload_detached: + segments[1] = b"" encoded_string = b".".join(segments) return encoded_string.decode("utf-8") @@ -134,6 +156,7 @@ def decode_complete( key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict] = None, + detached_payload: Optional[bytes] = None, **kwargs, ) -> Dict[str, Any]: if options is None: @@ -148,6 +171,14 @@ def decode_complete( payload, signing_input, header, signature = self._load(jwt) + if header.get("b64", True) is False: + if detached_payload is None: + raise DecodeError( + 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.' + ) + payload = detached_payload + signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload]) + if verify_signature: self._verify_signature(signing_input, header, signature, key, algorithms) diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 07d15110..0a0e2954 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -719,3 +719,54 @@ def test_encode_fails_on_invalid_kid_types(self, jws, payload): jws.encode(payload, "secret", headers={"kid": None}) assert "Key ID header parameter must be a string" == str(exc.value) + + def test_encode_decode_with_detached_content(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + jws.decode(jws_message, secret, algorithms=["HS256"], detached_payload=payload) + + def test_encode_detached_content_with_b64_header(self, jws, payload): + secret = "secret" + + # Check that detached content is automatically detected when b64 is false + headers = {"b64": False} + token = jws.encode(payload, secret, "HS256", headers) + + msg_header, msg_payload, _ = token.split(".") + msg_header = base64url_decode(msg_header.encode()) + msg_header_obj = json.loads(msg_header) + + assert "b64" in msg_header_obj + assert msg_header_obj["b64"] is False + # Check that the payload is not inside the token + assert not msg_payload + + # Check that content is not detached and b64 header removed when b64 is true + headers = {"b64": True} + token = jws.encode(payload, secret, "HS256", headers) + + msg_header, msg_payload, _ = token.split(".") + msg_header = base64url_decode(msg_header.encode()) + msg_header_obj = json.loads(msg_header) + + assert "b64" not in msg_header_obj + assert msg_payload + + def test_decode_detached_content_without_proper_argument(self, jws): + example_jws = ( + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImI2NCI6ZmFsc2V9" + "." + ".65yNkX_ZH4A_6pHaTL_eI84OXOHtfl4K0k5UnlXZ8f4" + ) + example_secret = "secret" + + with pytest.raises(DecodeError) as exc: + jws.decode(example_jws, example_secret, algorithms=["HS256"]) + + assert ( + 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.' + in str(exc.value) + )