Skip to content

Commit

Permalink
Add cacheing functionality for JWK set (#781)
Browse files Browse the repository at this point in the history
* Initial implementation of ttl jwk set cache

(cherry picked from commit 479a7c1)

* Add unit test for jwk set cache

* Fix failed unit test

* Disable cache signing key by default

* Add a negative unit test for get_jwk_set

* Add functionality to force refresh the jwk set cache when no matching signing key can be found from the cache

* Add unit test for refresh cache

* Add unit test to unset cache when the network call throws error

* fix naming typo

* Update unit test naming

* Update comment

* Add check for lifespan

* Update comments for get_signing_key

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix ci error

* Add type declaration to fix CI error

* Add more unit tests to improve coverage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Try to increase test coverage to 100%

Co-authored-by: Jerry Wu <hawu@roku.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 1, 2022
1 parent ae3da74 commit fc5b94e
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 33 deletions.
13 changes: 13 additions & 0 deletions jwt/api_jwk.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import time

from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
Expand Down Expand Up @@ -110,3 +111,15 @@ def __getitem__(self, kid):
if key.key_id == kid:
return key
raise KeyError(f"keyset has no key for kid: {kid}")


class PyJWTSetWithTimestamp:
def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()

def get_jwk_set(self):
return self.jwk_set

def get_timestamp(self):
return self.timestamp
32 changes: 32 additions & 0 deletions jwt/jwk_set_cache.py
@@ -0,0 +1,32 @@
import time
from typing import Optional

from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp


class JWKSetCache:
def __init__(self, lifespan: int):
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan

def put(self, jwk_set: PyJWKSet):
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
# clear cache
self.jwk_set_with_timestamp = None

def get(self) -> Optional[PyJWKSet]:
if self.jwk_set_with_timestamp is None or self.is_expired():
return None

return self.jwk_set_with_timestamp.get_jwk_set()

def is_expired(self) -> bool:

return (
self.jwk_set_with_timestamp is not None
and self.lifespan > -1
and time.monotonic()
> self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
)
82 changes: 65 additions & 17 deletions jwt/jwks_client.py
@@ -1,31 +1,68 @@
import json
import urllib.request
from functools import lru_cache
from typing import Any, List
from typing import Any, List, Optional
from urllib.error import URLError

from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError
from .jwk_set_cache import JWKSetCache


class PyJWKClient:
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16):
def __init__(
self,
uri: str,
cache_keys: bool = False,
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
):
self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None

if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
# Default lifespan is 300 seconds (5 minutes).
if lifespan <= 0:
raise PyJWKClientError(
f'Lifespan must be greater than 0, the input is "{lifespan}"'
)
self.jwk_set_cache = JWKSetCache(lifespan)
else:
self.jwk_set_cache = None

if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore

def fetch_data(self) -> Any:
with urllib.request.urlopen(self.uri) as response:
return json.load(response)
jwk_set: Any = None
try:
with urllib.request.urlopen(self.uri) as response:
jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
else:
return jwk_set
finally:
if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)

def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
data = None
if self.jwk_set_cache is not None and not refresh:
data = self.jwk_set_cache.get()

if data is None:
data = self.fetch_data()

def get_jwk_set(self) -> PyJWKSet:
data = self.fetch_data()
return PyJWKSet.from_dict(data)

def get_signing_keys(self) -> List[PyJWK]:
jwk_set = self.get_jwk_set()
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
jwk_set = self.get_jwk_set(refresh)
signing_keys = [
jwk_set_key
for jwk_set_key in jwk_set.keys
Expand All @@ -39,21 +76,32 @@ def get_signing_keys(self) -> List[PyJWK]:

def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys()
signing_key = None

for key in signing_keys:
if key.key_id == kid:
signing_key = key
break
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)
# If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_keys = self.get_signing_keys(refresh=True)
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)

return signing_key

def get_signing_key_from_jwt(self, token: str) -> PyJWK:
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))

@staticmethod
def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
signing_key = None

for key in signing_keys:
if key.key_id == kid:
signing_key = key
break

return signing_key

0 comments on commit fc5b94e

Please sign in to comment.