From d5c93e40418ae55df7d0317e9ebd88be0250629a Mon Sep 17 00:00:00 2001 From: Lukas Lihotzki Date: Fri, 15 Sep 2023 12:51:21 +0200 Subject: [PATCH] ServiceApplicationClient: Add extra_jwt_headers --- .../oauth2/rfc6749/clients/service_application.py | 9 +++++++-- .../rfc6749/clients/test_service_application.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/oauthlib/oauth2/rfc6749/clients/service_application.py b/oauthlib/oauth2/rfc6749/clients/service_application.py index 8fb17377..3a76c1d6 100644 --- a/oauthlib/oauth2/rfc6749/clients/service_application.py +++ b/oauthlib/oauth2/rfc6749/clients/service_application.py @@ -68,6 +68,7 @@ def prepare_request_body(self, audience=None, expires_at=None, issued_at=None, + extra_jwt_headers=None, extra_claims=None, body='', scope=None, @@ -96,7 +97,11 @@ def prepare_request_body(self, :param issued_at: A unix timestamp of when the JWT was created. Defaults to now, i.e. ``time.time()``. - :param extra_claims: A dict of additional claims to include in the JWT. + :param extra_jwt_headers: A dict of additional headers to include + in the JWT header. + + :param extra_claims: A dict of additional claims to include + in the JWT payload. :param body: Existing request body (URL encoded string) to embed parameters into. This may contain extra parameters. Default ''. @@ -176,7 +181,7 @@ def prepare_request_body(self, claim.update(extra_claims or {}) - assertion = jwt.encode(claim, key, 'RS256') + assertion = jwt.encode(claim, key, 'RS256', extra_jwt_headers) assertion = to_unicode(assertion) kwargs['client_id'] = self.client_id diff --git a/tests/oauth2/rfc6749/clients/test_service_application.py b/tests/oauth2/rfc6749/clients/test_service_application.py index 84361d8b..08a599eb 100644 --- a/tests/oauth2/rfc6749/clients/test_service_application.py +++ b/tests/oauth2/rfc6749/clients/test_service_application.py @@ -114,18 +114,24 @@ def test_request_body(self, t): # Optional kwargs not_before = time() - 3600 jwt_id = '8zd15df4s35f43sd' + extra_jwt_headers = {'extra': 'header'} + extra_claims = {'extra': 'claim'} body = client.prepare_request_body(issuer=self.issuer, subject=self.subject, audience=self.audience, body=self.body, not_before=not_before, + extra_jwt_headers=extra_jwt_headers, + extra_claims=extra_claims, jwt_id=jwt_id) r = Request('https://a.b', body=body) self.assertEqual(r.isnot, 'empty') self.assertEqual(r.grant_type, ServiceApplicationClient.grant_type) - claim = jwt.decode(r.assertion, self.public_key, audience=self.audience, algorithms=['RS256']) + token = jwt.api_jwt.decode_complete(r.assertion, self.public_key, audience=self.audience, algorithms=['RS256']) + header = token['header'] + claim = token['payload'] self.assertEqual(claim['iss'], self.issuer) # audience verification is handled during decode now @@ -134,6 +140,9 @@ def test_request_body(self, t): self.assertEqual(claim['nbf'], not_before) self.assertEqual(claim['jti'], jwt_id) + self.assertLessEqual(extra_jwt_headers.items(), header.items()) + self.assertLessEqual(extra_claims.items(), claim.items()) + @patch('time.time') def test_request_body_no_initial_private_key(self, t): t.return_value = time()