From 9cebaeec957e31c4bca0f7434804ae5cc8d29857 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Wed, 29 Jun 2022 15:17:41 -0700 Subject: [PATCH 01/19] Initial implementation of ttl jwk set cache (cherry picked from commit 479a7c124d63113a2190bd48972cc19172215096) --- jwt/api_jwk.py | 17 +++++++++++++++++ jwt/jwk_set_cache.py | 29 +++++++++++++++++++++++++++++ jwt/jwks_client.py | 32 +++++++++++++++++++++++++++----- 3 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 jwt/jwk_set_cache.py diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index aae2cf12..e3e8ae82 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -1,4 +1,5 @@ import json +from datetime import datetime, timezone from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -108,3 +109,19 @@ def __getitem__(self, kid): if key.key_id == kid: return key raise KeyError(f"keyset has no key for kid: {kid}") + + +class PyJWTSetWithTimestamp: + def __init__(self, jwt_set: PyJWKSet, timestamp: datetime = None): + self.jwt_set = jwt_set + + if timestamp is None: + self.timestamp = datetime.now(timezone.utc) + else: + self.timestamp = timestamp + + def get_jwk_set(self): + return self.jwt_set + + def get_timestamp(self): + return self.timestamp diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py new file mode 100644 index 00000000..4127fb81 --- /dev/null +++ b/jwt/jwk_set_cache.py @@ -0,0 +1,29 @@ +from typing import Optional +from datetime import datetime, timezone + +from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp + + +class JWKSetCache: + def __init__(self, lifespan: int): + self.jwk_set_with_timestamp = None + self.lifespan = lifespan + + def put(self, jwk_set: PyJWKSet): + if jwk_set is not None: + self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set) + else: + # clear cache + self.jwk_set_with_timestamp = None + + def get(self) -> Optional: + if self.jwk_set_with_timestamp is None or self.is_expired(): + return None + + return self.jwk_set_with_timestamp.get_jwk_set() + + def is_expired(self) -> bool: + return self.jwk_set_with_timestamp is not None \ + and self.lifespan > -1 \ + and datetime.now(timezone.utc) > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan + diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 767b7179..79098f83 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,27 +1,49 @@ import json -import urllib.request +from urllib.request import urlopen +from urllib.error import URLError from functools import lru_cache from typing import Any, List from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token +from .jwk_set_cache import JWKSetCache from .exceptions import PyJWKClientError class PyJWKClient: - def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16): + def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16, + cache_jwk_set: bool = True, lifespan: int = 5): self.uri = uri + + if cache_jwk_set: + # Init jwt set cache with default or given lifespan. + self.jwk_set_cache = JWKSetCache(lifespan) + if cache_keys: # Cache signing keys # Ignore mypy (https://github.com/python/mypy/issues/2427) self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore def fetch_data(self) -> Any: - with urllib.request.urlopen(self.uri) as response: - return json.load(response) + try: + with urlopen(self.uri) as response: + jwk_set = json.load(response) + except URLError as e: + raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') + + if self.jwk_set_cache is not None: + self.jwk_set_cache.put(jwk_set) + + return jwk_set def get_jwk_set(self) -> PyJWKSet: - data = self.fetch_data() + data = None + if self.jwk_set_cache is not None: + data = self.jwk_set_cache.get() + + if data is None: + data = self.fetch_data() + return PyJWKSet.from_dict(data) def get_signing_keys(self) -> List[PyJWK]: From 41b6f7f65d252bf31632e82d255a0fac05880f9a Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Wed, 29 Jun 2022 16:44:37 -0700 Subject: [PATCH 02/19] Add unit test for jwk set cache --- jwt/api_jwk.py | 10 +++----- jwt/jwk_set_cache.py | 8 ++++--- jwt/jwks_client.py | 5 +++- tests/test_jwks_client.py | 50 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index e3e8ae82..e59f125d 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -1,5 +1,5 @@ import json -from datetime import datetime, timezone +import time from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -112,13 +112,9 @@ def __getitem__(self, kid): class PyJWTSetWithTimestamp: - def __init__(self, jwt_set: PyJWKSet, timestamp: datetime = None): + def __init__(self, jwt_set: PyJWKSet): self.jwt_set = jwt_set - - if timestamp is None: - self.timestamp = datetime.now(timezone.utc) - else: - self.timestamp = timestamp + self.timestamp = time.monotonic() def get_jwk_set(self): return self.jwt_set diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py index 4127fb81..6731977d 100644 --- a/jwt/jwk_set_cache.py +++ b/jwt/jwk_set_cache.py @@ -1,5 +1,6 @@ +import time from typing import Optional -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp @@ -23,7 +24,8 @@ def get(self) -> Optional: return self.jwk_set_with_timestamp.get_jwk_set() def is_expired(self) -> bool: + return self.jwk_set_with_timestamp is not None \ and self.lifespan > -1 \ - and datetime.now(timezone.utc) > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan - + and time.monotonic() > \ + self.jwk_set_with_timestamp.get_timestamp() + self.lifespan diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 79098f83..06d37e74 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -12,12 +12,15 @@ class PyJWKClient: def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16, - cache_jwk_set: bool = True, lifespan: int = 5): + cache_jwk_set: bool = True, lifespan: int = 300): self.uri = uri if cache_jwk_set: # Init jwt set cache with default or given lifespan. + # Default lifespan is 300 seconds (5 minutes). self.jwk_set_cache = JWKSetCache(lifespan) + else: + self.jwk_set_cache = None if cache_keys: # Cache signing keys diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 3e42da17..20ce5556 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -1,5 +1,6 @@ import contextlib import json +import time from unittest import mock import pytest @@ -159,3 +160,52 @@ def test_get_signing_key_from_jwt(self): "azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC", "gty": "client-credentials", } + + def test_get_jwk_set_caches_result(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url) + + with mocked_response(RESPONSE_DATA): + jwks_client.get_jwk_set() + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_response(RESPONSE_DATA) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 0 + + def test_get_jwt_set_cache_expired_result(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url, lifespan=1) + with mocked_response(RESPONSE_DATA): + jwks_client.get_jwk_set() + + time.sleep(1) + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_response(RESPONSE_DATA) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 1 + + def test_get_jwt_set_cache_disabled(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url, cache_jwk_set=False, lifespan=1) + with mocked_response(RESPONSE_DATA): + jwks_client.get_jwk_set() + + time.sleep(1) + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_response(RESPONSE_DATA) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 1 + + From 544d7ca1a958b445605235fc8dea786e02c9b991 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 30 Jun 2022 11:29:17 -0700 Subject: [PATCH 03/19] Fix failed unit test --- jwt/jwks_client.py | 4 ++-- tests/test_jwks_client.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 06d37e74..7fc9a6b8 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,5 +1,5 @@ import json -from urllib.request import urlopen +import urllib.request from urllib.error import URLError from functools import lru_cache from typing import Any, List @@ -29,7 +29,7 @@ def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16, def fetch_data(self) -> Any: try: - with urlopen(self.uri) as response: + with urllib.request.urlopen(self.uri) as response: jwk_set = json.load(response) except URLError as e: raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 20ce5556..eeaf7870 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -123,9 +123,9 @@ def test_get_signing_key_does_not_cache_opt_out(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url, cache_keys=False) + jwks_client = PyJWKClient(url, cache_keys=False, cache_jwk_set=False) - with mocked_response(RESPONSE_DATA): + with mocked_response(RESPONSE_DATA) as first_call: jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice From ac9ae720287847ebaf0219030d25fe73f464ddfa Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 30 Jun 2022 14:22:14 -0700 Subject: [PATCH 04/19] Disable cache signing key by default --- jwt/jwks_client.py | 2 +- tests/test_jwks_client.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 7fc9a6b8..4b2c8013 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -11,7 +11,7 @@ class PyJWKClient: - def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16, + def __init__(self, uri: str, cache_keys: bool = False, max_cached_keys: int = 16, cache_jwk_set: bool = True, lifespan: int = 300): self.uri = uri diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index eeaf7870..cc3f8bc1 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -107,7 +107,7 @@ def test_get_signing_key_caches_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url) + jwks_client = PyJWKClient(url, cache_keys=True) with mocked_response(RESPONSE_DATA): jwks_client.get_signing_key(kid) @@ -123,9 +123,9 @@ def test_get_signing_key_does_not_cache_opt_out(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url, cache_keys=False, cache_jwk_set=False) + jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_response(RESPONSE_DATA) as first_call: + with mocked_response(RESPONSE_DATA): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice @@ -195,7 +195,7 @@ def test_get_jwt_set_cache_expired_result(self): def test_get_jwt_set_cache_disabled(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - jwks_client = PyJWKClient(url, cache_jwk_set=False, lifespan=1) + jwks_client = PyJWKClient(url, cache_jwk_set=False) with mocked_response(RESPONSE_DATA): jwks_client.get_jwk_set() From a209586e570dc4fae7680093a480d4c3892b09f4 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 30 Jun 2022 14:49:20 -0700 Subject: [PATCH 05/19] Add a negative unit test for get_jwk_set --- tests/test_jwks_client.py | 48 +++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index cc3f8bc1..1115f61f 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -2,6 +2,7 @@ import json import time from unittest import mock +from urllib.error import URLError import pytest @@ -31,7 +32,7 @@ @contextlib.contextmanager -def mocked_response(data): +def mocked_success_response(data): with mock.patch("urllib.request.urlopen") as urlopen_mock: response = mock.Mock() response.__enter__ = mock.Mock(return_value=response) @@ -41,12 +42,19 @@ def mocked_response(data): yield urlopen_mock +@contextlib.contextmanager +def mocked_failed_response(): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + urlopen_mock.side_effect = URLError("Fail to process the request.") + yield urlopen_mock + + @crypto_required class TestPyJWKClient: def test_get_jwk_set(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client = PyJWKClient(url) jwk_set = jwks_client.get_jwk_set() @@ -55,7 +63,7 @@ def test_get_jwk_set(self): def test_get_signing_keys(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -69,7 +77,7 @@ def test_get_signing_keys_if_no_use_provided(self): del mocked_key["use"] response = {"keys": [mocked_key]} - with mocked_response(response): + with mocked_success_response(response): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -82,7 +90,7 @@ def test_get_signing_keys_raises_if_none_found(self): mocked_key = RESPONSE_DATA["keys"][0].copy() mocked_key["use"] = "enc" response = {"keys": [mocked_key]} - with mocked_response(response): + with mocked_success_response(response): jwks_client = PyJWKClient(url) with pytest.raises(PyJWKClientError) as exc: @@ -94,7 +102,7 @@ def test_get_signing_key(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key(kid) @@ -109,12 +117,12 @@ def test_get_signing_key_caches_result(self): jwks_client = PyJWKClient(url, cache_keys=True) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 0 @@ -125,12 +133,12 @@ def test_get_signing_key_does_not_cache_opt_out(self): jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 1 @@ -139,7 +147,7 @@ def test_get_signing_key_from_jwt(self): token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA" url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key_from_jwt(token) @@ -166,12 +174,12 @@ def test_get_jwk_set_caches_result(self): jwks_client = PyJWKClient(url) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client.get_jwk_set() # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 0 @@ -180,14 +188,14 @@ def test_get_jwt_set_cache_expired_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, lifespan=1) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client.get_jwk_set() time.sleep(1) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 1 @@ -196,16 +204,22 @@ def test_get_jwt_set_cache_disabled(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA): jwks_client.get_jwk_set() time.sleep(1) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 1 + def test_get_jwt_set_failed_request(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + with pytest.raises(PyJWKClientError): + with mocked_failed_response(): + jwks_client.get_jwk_set() From b21d67bf5f9c6f9a24a3591b979c8e0b22b634de Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 30 Jun 2022 16:26:58 -0700 Subject: [PATCH 06/19] Add functionality to force refresh the jwk set cache when no matching signing key can be found from the cache --- jwt/jwks_client.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 4b2c8013..21f0a5fe 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -2,7 +2,7 @@ import urllib.request from urllib.error import URLError from functools import lru_cache -from typing import Any, List +from typing import Any, List, Optional from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token @@ -39,9 +39,9 @@ def fetch_data(self) -> Any: return jwk_set - def get_jwk_set(self) -> PyJWKSet: + def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: data = None - if self.jwk_set_cache is not None: + if self.jwk_set_cache is not None and not refresh: data = self.jwk_set_cache.get() if data is None: @@ -49,8 +49,8 @@ def get_jwk_set(self) -> PyJWKSet: return PyJWKSet.from_dict(data) - def get_signing_keys(self) -> List[PyJWK]: - jwk_set = self.get_jwk_set() + def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: + jwk_set = self.get_jwk_set(refresh) signing_keys = [ jwk_set_key for jwk_set_key in jwk_set.keys @@ -64,17 +64,17 @@ def get_signing_keys(self) -> List[PyJWK]: def get_signing_key(self, kid: str) -> PyJWK: signing_keys = self.get_signing_keys() - signing_key = None - - for key in signing_keys: - if key.key_id == kid: - signing_key = key - break + signing_key = self.match_kid(signing_keys, kid) if not signing_key: - raise PyJWKClientError( - f'Unable to find a signing key that matches: "{kid}"' - ) + # If no matching signing key from the cached jwk set, refresh the jwk set. + signing_keys = self.get_signing_keys(refresh=True) + signing_key = self.match_kid(signing_keys, kid) + + if not signing_key: + raise PyJWKClientError( + f'Unable to find a signing key that matches: "{kid}"' + ) return signing_key @@ -82,3 +82,14 @@ def get_signing_key_from_jwt(self, token: str) -> PyJWK: unverified = decode_token(token, options={"verify_signature": False}) header = unverified["header"] return self.get_signing_key(header.get("kid")) + + @staticmethod + def match_kid(signing_keys: list[PyJWK], kid: str) -> Optional[PyJWK]: + signing_key = None + + for key in signing_keys: + if key.key_id == kid: + signing_key = key + break + + return signing_key From 70d1d2f394f5ae1c85e1476656ea880899424c54 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 30 Jun 2022 17:18:11 -0700 Subject: [PATCH 07/19] Add unit test for refresh cache --- tests/test_jwks_client.py | 69 +++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 1115f61f..f8b91677 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -13,7 +13,7 @@ from .utils import crypto_required -RESPONSE_DATA = { +RESPONSE_DATA_ONE = { "keys": [ { "alg": "RS256", @@ -30,6 +30,19 @@ ] } +RESPONSE_DATA_TWO = { + "keys": [ + { + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ", + "e": "AQAB", + "kid": "MLYHNMMhwCNXw9roHIILFsK4nLs=", + } + ] +} + @contextlib.contextmanager def mocked_success_response(data): @@ -49,12 +62,23 @@ def mocked_failed_response(): yield urlopen_mock +@contextlib.contextmanager +def mocked_first_call_empty_second_call_with_response(response_data_one, response_data_two): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + response = mock.Mock() + response.__enter__ = mock.Mock(return_value=response) + response.__exit__ = mock.Mock() + response.read.side_effect = [json.dumps(response_data_one), json.dumps(response_data_two)] + urlopen_mock.return_value = response + yield urlopen_mock + + @crypto_required class TestPyJWKClient: def test_get_jwk_set(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client = PyJWKClient(url) jwk_set = jwks_client.get_jwk_set() @@ -63,7 +87,7 @@ def test_get_jwk_set(self): def test_get_signing_keys(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -73,7 +97,7 @@ def test_get_signing_keys(self): def test_get_signing_keys_if_no_use_provided(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA["keys"][0].copy() + mocked_key = RESPONSE_DATA_ONE["keys"][0].copy() del mocked_key["use"] response = {"keys": [mocked_key]} @@ -87,7 +111,7 @@ def test_get_signing_keys_if_no_use_provided(self): def test_get_signing_keys_raises_if_none_found(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA["keys"][0].copy() + mocked_key = RESPONSE_DATA_ONE["keys"][0].copy() mocked_key["use"] = "enc" response = {"keys": [mocked_key]} with mocked_success_response(response): @@ -102,7 +126,7 @@ def test_get_signing_key(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key(kid) @@ -117,12 +141,12 @@ def test_get_signing_key_caches_result(self): jwks_client = PyJWKClient(url, cache_keys=True) - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 0 @@ -133,12 +157,12 @@ def test_get_signing_key_does_not_cache_opt_out(self): jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 1 @@ -147,7 +171,7 @@ def test_get_signing_key_from_jwt(self): token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA" url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key_from_jwt(token) @@ -174,12 +198,12 @@ def test_get_jwk_set_caches_result(self): jwks_client = PyJWKClient(url) - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client.get_jwk_set() # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 0 @@ -188,14 +212,14 @@ def test_get_jwt_set_cache_expired_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, lifespan=1) - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client.get_jwk_set() time.sleep(1) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 1 @@ -204,14 +228,14 @@ def test_get_jwt_set_cache_disabled(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_success_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_ONE): jwks_client.get_jwk_set() time.sleep(1) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 1 @@ -223,3 +247,14 @@ def test_get_jwt_set_failed_request(self): with pytest.raises(PyJWKClientError): with mocked_failed_response(): jwks_client.get_jwk_set() + + def test_get_jwt_set_refresh_cache(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + with mocked_first_call_empty_second_call_with_response(RESPONSE_DATA_TWO, RESPONSE_DATA_ONE) as call_data: + jwks_client.get_signing_key(kid) + + assert call_data.call_count == 2 From d611b8dcf8222ce39fd30404ed13a5a1fcc06fc2 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 30 Jun 2022 17:19:41 -0700 Subject: [PATCH 08/19] Add unit test to unset cache when the network call throws error --- tests/test_jwks_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index f8b91677..5439ab40 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -248,6 +248,8 @@ def test_get_jwt_set_failed_request(self): with mocked_failed_response(): jwks_client.get_jwk_set() + assert jwks_client.jwk_set_cache is None + def test_get_jwt_set_refresh_cache(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url) From 56076f0790f1ae9936ba1652fc861a74c720caa6 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 1 Jul 2022 10:20:44 -0700 Subject: [PATCH 09/19] fix naming typo --- jwt/api_jwk.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index e59f125d..2457c461 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -112,12 +112,12 @@ def __getitem__(self, kid): class PyJWTSetWithTimestamp: - def __init__(self, jwt_set: PyJWKSet): - self.jwt_set = jwt_set + def __init__(self, jwk_set: PyJWKSet): + self.jwk_set = jwk_set self.timestamp = time.monotonic() def get_jwk_set(self): - return self.jwt_set + return self.jwk_set def get_timestamp(self): return self.timestamp From a4c28d16801d25315f5e6a70ceb35b4c3c098c06 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 1 Jul 2022 11:26:32 -0700 Subject: [PATCH 10/19] Update unit test naming --- jwt/jwk_set_cache.py | 3 +-- tests/test_jwks_client.py | 42 ++++++++++++++++++++------------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py index 6731977d..d8c09ee2 100644 --- a/jwt/jwk_set_cache.py +++ b/jwt/jwk_set_cache.py @@ -1,6 +1,5 @@ import time from typing import Optional -from datetime import datetime, timezone, timedelta from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp @@ -17,7 +16,7 @@ def put(self, jwk_set: PyJWKSet): # clear cache self.jwk_set_with_timestamp = None - def get(self) -> Optional: + def get(self) -> Optional[PyJWKSet]: if self.jwk_set_with_timestamp is None or self.is_expired(): return None diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 5439ab40..b195b47e 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -13,7 +13,7 @@ from .utils import crypto_required -RESPONSE_DATA_ONE = { +RESPONSE_DATA_WITH_MATCHING_KID = { "keys": [ { "alg": "RS256", @@ -30,7 +30,7 @@ ] } -RESPONSE_DATA_TWO = { +RESPONSE_DATA_NO_MATCHING_KID = { "keys": [ { "alg": "RS256", @@ -63,7 +63,7 @@ def mocked_failed_response(): @contextlib.contextmanager -def mocked_first_call_empty_second_call_with_response(response_data_one, response_data_two): +def mocked_first_call_wrong_kid_second_call_correct_kid(response_data_one, response_data_two): with mock.patch("urllib.request.urlopen") as urlopen_mock: response = mock.Mock() response.__enter__ = mock.Mock(return_value=response) @@ -78,7 +78,7 @@ class TestPyJWKClient: def test_get_jwk_set(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) jwk_set = jwks_client.get_jwk_set() @@ -87,7 +87,7 @@ def test_get_jwk_set(self): def test_get_signing_keys(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -97,7 +97,7 @@ def test_get_signing_keys(self): def test_get_signing_keys_if_no_use_provided(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA_ONE["keys"][0].copy() + mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy() del mocked_key["use"] response = {"keys": [mocked_key]} @@ -111,7 +111,7 @@ def test_get_signing_keys_if_no_use_provided(self): def test_get_signing_keys_raises_if_none_found(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA_ONE["keys"][0].copy() + mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy() mocked_key["use"] = "enc" response = {"keys": [mocked_key]} with mocked_success_response(response): @@ -126,7 +126,7 @@ def test_get_signing_key(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key(kid) @@ -141,12 +141,12 @@ def test_get_signing_key_caches_result(self): jwks_client = PyJWKClient(url, cache_keys=True) - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 0 @@ -157,12 +157,12 @@ def test_get_signing_key_does_not_cache_opt_out(self): jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 1 @@ -171,7 +171,7 @@ def test_get_signing_key_from_jwt(self): token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA" url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key_from_jwt(token) @@ -198,12 +198,12 @@ def test_get_jwk_set_caches_result(self): jwks_client = PyJWKClient(url) - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 0 @@ -212,14 +212,14 @@ def test_get_jwt_set_cache_expired_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, lifespan=1) - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() time.sleep(1) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 1 @@ -228,14 +228,14 @@ def test_get_jwt_set_cache_disabled(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_success_response(RESPONSE_DATA_ONE): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() time.sleep(1) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_success_response(RESPONSE_DATA_ONE) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_jwk_set() assert repeated_call.call_count == 1 @@ -256,7 +256,9 @@ def test_get_jwt_set_refresh_cache(self): kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - with mocked_first_call_empty_second_call_with_response(RESPONSE_DATA_TWO, RESPONSE_DATA_ONE) as call_data: + # The first call will return + with mocked_first_call_wrong_kid_second_call_correct_kid( + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID) as call_data: jwks_client.get_signing_key(kid) assert call_data.call_count == 2 From e4b29b02dc0fc6fd6aeeb560488f28922f9765f8 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 1 Jul 2022 11:27:52 -0700 Subject: [PATCH 11/19] Update comment --- tests/test_jwks_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index b195b47e..153a5502 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -256,7 +256,8 @@ def test_get_jwt_set_refresh_cache(self): kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - # The first call will return + # The first call will return response with no matching kid, + # the function should make another call to try to refresh the cache. with mocked_first_call_wrong_kid_second_call_correct_kid( RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID) as call_data: jwks_client.get_signing_key(kid) From 2c1bd08e44c9918d9bdf63c496470fa749aaaa1c Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Sun, 10 Jul 2022 23:17:49 -0700 Subject: [PATCH 12/19] Add check for lifespan --- jwt/jwks_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 21f0a5fe..fbcbe058 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -18,6 +18,8 @@ def __init__(self, uri: str, cache_keys: bool = False, max_cached_keys: int = 16 if cache_jwk_set: # Init jwt set cache with default or given lifespan. # Default lifespan is 300 seconds (5 minutes). + if lifespan < 0: + raise PyJWKClientError(f'Lifespan must be greater than 0, the input is "{lifespan}"') self.jwk_set_cache = JWKSetCache(lifespan) else: self.jwk_set_cache = None From 913017e1682ad660f5dc440e607c14ebd23f9130 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Sun, 10 Jul 2022 23:22:59 -0700 Subject: [PATCH 13/19] Update comments for get_signing_key --- jwt/jwks_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index fbcbe058..0f35b601 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -69,7 +69,7 @@ def get_signing_key(self, kid: str) -> PyJWK: signing_key = self.match_kid(signing_keys, kid) if not signing_key: - # If no matching signing key from the cached jwk set, refresh the jwk set. + # If no matching signing key from the jwk set, refresh the jwk set and try again. signing_keys = self.get_signing_keys(refresh=True) signing_key = self.match_kid(signing_keys, kid) From f7e3cade3f7508f28aed07c0533db501aa595b29 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Jul 2022 06:29:17 +0000 Subject: [PATCH 14/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jwt/jwk_set_cache.py | 10 ++++++---- jwt/jwks_client.py | 18 +++++++++++++----- tests/test_jwks_client.py | 12 +++++++++--- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py index d8c09ee2..2f52aa4d 100644 --- a/jwt/jwk_set_cache.py +++ b/jwt/jwk_set_cache.py @@ -24,7 +24,9 @@ def get(self) -> Optional[PyJWKSet]: def is_expired(self) -> bool: - return self.jwk_set_with_timestamp is not None \ - and self.lifespan > -1 \ - and time.monotonic() > \ - self.jwk_set_with_timestamp.get_timestamp() + self.lifespan + return ( + self.jwk_set_with_timestamp is not None + and self.lifespan > -1 + and time.monotonic() + > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan + ) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 0f35b601..8d407ac1 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,25 +1,33 @@ import json import urllib.request -from urllib.error import URLError from functools import lru_cache from typing import Any, List, Optional +from urllib.error import URLError from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token -from .jwk_set_cache import JWKSetCache from .exceptions import PyJWKClientError +from .jwk_set_cache import JWKSetCache class PyJWKClient: - def __init__(self, uri: str, cache_keys: bool = False, max_cached_keys: int = 16, - cache_jwk_set: bool = True, lifespan: int = 300): + def __init__( + self, + uri: str, + cache_keys: bool = False, + max_cached_keys: int = 16, + cache_jwk_set: bool = True, + lifespan: int = 300, + ): self.uri = uri if cache_jwk_set: # Init jwt set cache with default or given lifespan. # Default lifespan is 300 seconds (5 minutes). if lifespan < 0: - raise PyJWKClientError(f'Lifespan must be greater than 0, the input is "{lifespan}"') + raise PyJWKClientError( + f'Lifespan must be greater than 0, the input is "{lifespan}"' + ) self.jwk_set_cache = JWKSetCache(lifespan) else: self.jwk_set_cache = None diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 153a5502..6111c69b 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -63,12 +63,17 @@ def mocked_failed_response(): @contextlib.contextmanager -def mocked_first_call_wrong_kid_second_call_correct_kid(response_data_one, response_data_two): +def mocked_first_call_wrong_kid_second_call_correct_kid( + response_data_one, response_data_two +): with mock.patch("urllib.request.urlopen") as urlopen_mock: response = mock.Mock() response.__enter__ = mock.Mock(return_value=response) response.__exit__ = mock.Mock() - response.read.side_effect = [json.dumps(response_data_one), json.dumps(response_data_two)] + response.read.side_effect = [ + json.dumps(response_data_one), + json.dumps(response_data_two), + ] urlopen_mock.return_value = response yield urlopen_mock @@ -259,7 +264,8 @@ def test_get_jwt_set_refresh_cache(self): # The first call will return response with no matching kid, # the function should make another call to try to refresh the cache. with mocked_first_call_wrong_kid_second_call_correct_kid( - RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID) as call_data: + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID + ) as call_data: jwks_client.get_signing_key(kid) assert call_data.call_count == 2 From 3dfe73fdb690e899b2cfd016a8344316a44140e1 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Sun, 10 Jul 2022 23:34:16 -0700 Subject: [PATCH 15/19] Fix ci error --- jwt/jwks_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 8d407ac1..fe3c3acd 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -94,7 +94,7 @@ def get_signing_key_from_jwt(self, token: str) -> PyJWK: return self.get_signing_key(header.get("kid")) @staticmethod - def match_kid(signing_keys: list[PyJWK], kid: str) -> Optional[PyJWK]: + def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]: signing_key = None for key in signing_keys: From 8c595b3a7349751a217a3f51e22bd98569f2eded Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Mon, 11 Jul 2022 09:55:21 -0700 Subject: [PATCH 16/19] Add type declaration to fix CI error --- jwt/jwk_set_cache.py | 2 +- jwt/jwks_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py index 2f52aa4d..e8c2a7e0 100644 --- a/jwt/jwk_set_cache.py +++ b/jwt/jwk_set_cache.py @@ -6,7 +6,7 @@ class JWKSetCache: def __init__(self, lifespan: int): - self.jwk_set_with_timestamp = None + self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None self.lifespan = lifespan def put(self, jwk_set: PyJWKSet): diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index fe3c3acd..18811a85 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -20,6 +20,7 @@ def __init__( lifespan: int = 300, ): self.uri = uri + self.jwk_set_cache: Optional[JWKSetCache] = None if cache_jwk_set: # Init jwt set cache with default or given lifespan. From e5dc7f7ec2b9083d90baca4c12e785b78d5aba94 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Mon, 11 Jul 2022 12:38:43 -0700 Subject: [PATCH 17/19] Add more unit tests to improve coverage --- jwt/jwks_client.py | 21 +++++++++++---------- tests/test_jwks_client.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 18811a85..76b6f1d6 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -12,12 +12,12 @@ class PyJWKClient: def __init__( - self, - uri: str, - cache_keys: bool = False, - max_cached_keys: int = 16, - cache_jwk_set: bool = True, - lifespan: int = 300, + self, + uri: str, + cache_keys: bool = False, + max_cached_keys: int = 16, + cache_jwk_set: bool = True, + lifespan: int = 300, ): self.uri = uri self.jwk_set_cache: Optional[JWKSetCache] = None @@ -25,7 +25,7 @@ def __init__( if cache_jwk_set: # Init jwt set cache with default or given lifespan. # Default lifespan is 300 seconds (5 minutes). - if lifespan < 0: + if lifespan <= 0: raise PyJWKClientError( f'Lifespan must be greater than 0, the input is "{lifespan}"' ) @@ -39,14 +39,15 @@ def __init__( self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore def fetch_data(self) -> Any: + jwk_set: Any = None try: with urllib.request.urlopen(self.uri) as response: jwk_set = json.load(response) except URLError as e: raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') - - if self.jwk_set_cache is not None: - self.jwk_set_cache.put(jwk_set) + finally: + if self.jwk_set_cache is not None: + self.jwk_set_cache.put(jwk_set) return jwk_set diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 6111c69b..e5cf5c54 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -220,7 +220,7 @@ def test_get_jwt_set_cache_expired_result(self): with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() - time.sleep(1) + time.sleep(2) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed @@ -236,7 +236,7 @@ def test_get_jwt_set_cache_disabled(self): with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() - time.sleep(1) + time.sleep(2) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed @@ -245,13 +245,16 @@ def test_get_jwt_set_cache_disabled(self): assert repeated_call.call_count == 1 - def test_get_jwt_set_failed_request(self): + def test_get_jwt_set_failed_request_should_clear_cache(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url) + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + with pytest.raises(PyJWKClientError): with mocked_failed_response(): - jwks_client.get_jwk_set() + jwks_client.get_jwk_set(refresh=True) assert jwks_client.jwk_set_cache is None @@ -269,3 +272,22 @@ def test_get_jwt_set_refresh_cache(self): jwks_client.get_signing_key(kid) assert call_data.call_count == 2 + + def test_get_jwt_set_no_matching_kid_after_second_attempt(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + with pytest.raises(PyJWKClientError): + with mocked_first_call_wrong_kid_second_call_correct_kid( + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_NO_MATCHING_KID + ): + jwks_client.get_signing_key(kid) + + def test_get_jwt_set_invalid_lifespan(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with pytest.raises(PyJWKClientError): + jwks_client = PyJWKClient(url, lifespan=-1) + assert jwks_client is None From 6d9149713e41d98c350cb93e61f1233edf8936c1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Jul 2022 19:39:02 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jwt/jwks_client.py | 12 ++++++------ tests/test_jwks_client.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 76b6f1d6..4f4a705f 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -12,12 +12,12 @@ class PyJWKClient: def __init__( - self, - uri: str, - cache_keys: bool = False, - max_cached_keys: int = 16, - cache_jwk_set: bool = True, - lifespan: int = 300, + self, + uri: str, + cache_keys: bool = False, + max_cached_keys: int = 16, + cache_jwk_set: bool = True, + lifespan: int = 300, ): self.uri = uri self.jwk_set_cache: Optional[JWKSetCache] = None diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index e5cf5c54..e963828f 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -281,7 +281,7 @@ def test_get_jwt_set_no_matching_kid_after_second_attempt(self): with pytest.raises(PyJWKClientError): with mocked_first_call_wrong_kid_second_call_correct_kid( - RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_NO_MATCHING_KID + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_NO_MATCHING_KID ): jwks_client.get_signing_key(kid) From 6bb6aa753599f10b5ffa4c9fb531f2c044f50390 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Mon, 11 Jul 2022 15:05:49 -0700 Subject: [PATCH 19/19] Try to increase test coverage to 100% --- jwt/jwks_client.py | 4 ++-- tests/test_jwks_client.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 4f4a705f..b4e98007 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -45,12 +45,12 @@ def fetch_data(self) -> Any: jwk_set = json.load(response) except URLError as e: raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') + else: + return jwk_set finally: if self.jwk_set_cache is not None: self.jwk_set_cache.put(jwk_set) - return jwk_set - def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: data = None if self.jwk_set_cache is not None and not refresh: diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index e963828f..c95dfcc0 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -202,6 +202,7 @@ def test_get_jwk_set_caches_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url) + assert jwks_client.jwk_set_cache is not None with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() @@ -233,9 +234,13 @@ def test_get_jwt_set_cache_disabled(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" jwks_client = PyJWKClient(url, cache_jwk_set=False) + assert jwks_client.jwk_set_cache is None + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_jwk_set() + assert jwks_client.jwk_set_cache is None + time.sleep(2) # mocked_response does not allow urllib.request.urlopen to be called twice