{Core} Retry loading MSAL HTTP cache in case of failure (#20722)
jiasli committed Dec 27, 2021
1 parent 103c4e9 commit fb805f4
Showing 4 changed files with 79 additions and 37 deletions.
7 changes: 6 additions & 1 deletion src/azure-cli-core/azure/cli/core/_profile.py
@@ -862,7 +862,12 @@ def _create_identity_instance(cli_ctx, *args, **kwargs):

# Only enable encryption for Windows (for now).
fallback = sys.platform.startswith('win32')

# EXPERIMENTAL: Use core.encrypt_token_cache=False to turn off token cache encryption.
# encrypt_token_cache affects both MSAL token cache and service principal entries.
encrypt = cli_ctx.config.getboolean('core', 'encrypt_token_cache', fallback=fallback)

return Identity(*args, encrypt=encrypt, **kwargs)
# EXPERIMENTAL: Use core.use_msal_http_cache=False to turn off MSAL HTTP cache.
use_msal_http_cache = cli_ctx.config.getboolean('core', 'use_msal_http_cache', fallback=True)

return Identity(*args, encrypt=encrypt, use_msal_http_cache=use_msal_http_cache, **kwargs)
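Both EXPERIMENTAL knobs are read through the CLI config with an explicit fallback, so an absent key keeps the default behavior. A rough, standalone illustration of how such a lookup resolves, using plain configparser against an INI snippet shaped like the CLI config file (typically ~/.azure/config) rather than the CLI's own config object:

import configparser

config = configparser.ConfigParser()
config.read_string("""
[core]
encrypt_token_cache = false
""")

# Key present in [core]: the stored value wins over the fallback.
encrypt = config.getboolean('core', 'encrypt_token_cache', fallback=True)              # -> False

# Key absent: the fallback wins, so the MSAL HTTP cache stays enabled by default.
use_msal_http_cache = config.getboolean('core', 'use_msal_http_cache', fallback=True)  # -> True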
54 changes: 39 additions & 15 deletions src/azure-cli-core/azure/cli/core/auth/identity.py
@@ -5,9 +5,11 @@

import json
import os
import pickle
import re

from azure.cli.core._environment import get_config_dir
from azure.cli.core.decorators import retry
from msal import PublicClientApplication

from knack.log import get_logger
@@ -45,7 +47,7 @@ class Identity: # pylint: disable=too-many-instance-attributes
# It follows singleton pattern so that _secret_file is read only once.
_service_principal_store_instance = None

def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False):
def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False, use_msal_http_cache=True):
"""
:param authority: Authentication authority endpoint. For example,
- AAD: https://login.microsoftonline.com
@@ -58,7 +60,8 @@ def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False):
self.authority = authority
self.tenant_id = tenant_id
self.client_id = client_id or AZURE_CLI_CLIENT_ID
self.encrypt = encrypt
self._encrypt = encrypt
self._use_msal_http_cache = use_msal_http_cache

# Build the authority in MSAL style
self._msal_authority, self._is_adfs = _get_authority_url(authority, tenant_id)
@@ -80,7 +83,7 @@ def _msal_app_kwargs(self):
if not Identity._msal_token_cache:
Identity._msal_token_cache = self._load_msal_token_cache()

if not Identity._msal_http_cache:
if self._use_msal_http_cache and not Identity._msal_http_cache:
Identity._msal_http_cache = self._load_msal_http_cache()

return {
@@ -100,25 +103,46 @@ def _msal_app(self):

def _load_msal_token_cache(self):
# Store for user token persistence
cache = load_persisted_token_cache(self._token_cache_file, self.encrypt)
cache = load_persisted_token_cache(self._token_cache_file, self._encrypt)
return cache

@retry()
def __load_msal_http_cache(self):
"""Load MSAL HTTP cache with retry. If it still fails at last, raise the original exception as-is."""
logger.debug("__load_msal_http_cache: %s", self._http_cache_file)
try:
with open(self._http_cache_file, 'rb') as f:
return pickle.load(f)
except FileNotFoundError:
# The cache file has not been created. This is expected.
logger.debug("%s not found. Using a fresh one.", self._http_cache_file)
return {}

def _dump_msal_http_cache(self):
logger.debug("_dump_msal_http_cache: %s", self._http_cache_file)
with open(self._http_cache_file, 'wb') as f:
# At this point, an empty cache file has already been created by open(). Loading this cache
# file from another process would trigger EOFError. This can be simulated by adding
# time.sleep(30) here. So, during loading, EOFError is ignored.
pickle.dump(self._msal_http_cache, f)

def _load_msal_http_cache(self):
import atexit
import pickle

logger.debug("_load_msal_http_cache: %s", self._http_cache_file)
try:
with open(self._http_cache_file, 'rb') as f:
persisted_http_cache = pickle.load(f)
except (pickle.UnpicklingError, FileNotFoundError) as ex:
logger.debug("Failed to load MSAL HTTP cache: %s", ex)
persisted_http_cache = self.__load_msal_http_cache()
except (pickle.UnpicklingError, EOFError) as ex:
# We still get an exception after retrying:
# - pickle.UnpicklingError is caused by a corrupted cache file, perhaps due to concurrent writes.
# - EOFError is caused by an empty cache file created by another az instance but not yet filled.
logger.debug("Failed to load MSAL HTTP cache: %s. Using a fresh one.", ex)
persisted_http_cache = {}  # Ignore a non-existent or corrupted http_cache
atexit.register(lambda: pickle.dump(
# When exiting, flush it back to the file.
# If 2 processes write at the same time, the cache will be corrupted,
# but that is fine. Subsequent runs would reach eventual consistency.
persisted_http_cache, open(self._http_cache_file, 'wb')))

# When exiting, flush it back to the file.
# If 2 processes write at the same time, the cache will be corrupted,
# but that is fine. Subsequent runs would reach eventual consistency.
atexit.register(self._dump_msal_http_cache)

return persisted_http_cache
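Together, __load_msal_http_cache, _dump_msal_http_cache and _load_msal_http_cache form a small best-effort persistence scheme: retry reads that race with a concurrent writer, fall back to an empty cache when the file is missing or unreadable, and flush once at process exit. A condensed, standalone sketch of that pattern (the path, function names and retry parameters below are illustrative, not the CLI's):

import atexit
import os
import pickle
import tempfile
import time

CACHE_FILE = os.path.join(tempfile.gettempdir(), 'demo_http_cache.bin')  # hypothetical path

def load_cache(retries=3, interval=0.5):
    for attempt in range(1, retries + 1):
        try:
            with open(CACHE_FILE, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return {}                 # no cache file yet: expected on first run
        except (pickle.UnpicklingError, EOFError):
            if attempt == retries:
                return {}             # still unreadable after retries: start fresh
            time.sleep(interval)      # another process may be mid-write; wait and retry
    return {}

cache = load_cache()

def dump_cache():
    # Flush once at exit. Two processes writing at the same time may corrupt the
    # file, but the next run just falls back to a fresh cache (eventual consistency).
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(cache, f)

atexit.register(dump_cache)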

@@ -128,7 +152,7 @@ def _service_principal_store(self):
The instance is lazily created.
"""
if not Identity._service_principal_store_instance:
store = load_secret_store(self._secret_file, self.encrypt)
store = load_secret_store(self._secret_file, self._encrypt)
Identity._service_principal_store_instance = ServicePrincipalStore(store)
return Identity._service_principal_store_instance

25 changes: 4 additions & 21 deletions src/azure-cli-core/azure/cli/core/auth/persistence.py
@@ -8,13 +8,14 @@

import json
import sys
import time

from msal_extensions import (FilePersistenceWithDataProtection, KeychainPersistence, LibsecretPersistence,
FilePersistence, PersistedTokenCache, CrossPlatLock)
from msal_extensions.persistence import PersistenceNotFound

from knack.log import get_logger
from azure.cli.core.decorators import retry


logger = get_logger(__name__)

@@ -60,27 +61,9 @@ def save(self, content):
with CrossPlatLock(self._lock_file):
self._persistence.save(json.dumps(content, indent=4))

def _load(self):
@retry()
def load(self):
try:
return json.loads(self._persistence.load())
except PersistenceNotFound:
return []

def load(self):
# Use optimistic locking rather than CrossPlatLock, so that multiple processes can
# read the same file at the same time.
retry = 3
for attempt in range(1, retry + 1):
try:
return self._load()
except Exception: # pylint: disable=broad-except
# Presumably other processes are writing the file, causing dirty read
if attempt < retry:
logger.debug("Unable to load secret store in No. %d attempt", attempt)
import traceback
logger.debug(traceback.format_exc())
time.sleep(0.5)
else:
raise # End of retry. Re-raise the exception as-is.

return [] # Not really reachable here. Just to keep pylint happy.
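The hand-rolled retry loop above (removed in this commit) is what the @retry() decorator now generalizes: reads stay lock-free, only save takes CrossPlatLock, and a read that catches a writer mid-flight simply tries again. A standalone sketch of the dirty-read scenario this guards against, using a plain JSON file (path and names below are illustrative):

import json
import os
import tempfile
import time

path = os.path.join(tempfile.gettempdir(), 'secret_store_demo.json')

with open(path, 'w') as f:
    f.write('[{"client": "demo"')        # simulate a writer caught mid-write

def load(retries=3, interval=0.1):
    for attempt in range(1, retries + 1):
        try:
            with open(path) as f:
                return json.load(f)      # dirty read raises json.JSONDecodeError
        except FileNotFoundError:
            return []                    # no store yet
        except json.JSONDecodeError:
            if attempt == retries:
                raise                    # end of retry: re-raise as-is, like @retry()
            time.sleep(interval)         # give the writer time to finish

with open(path, 'w') as f:
    json.dump([{"client": "demo"}], f)   # writer finishes

print(load())                            # -> [{'client': 'demo'}]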
30 changes: 30 additions & 0 deletions src/azure-cli-core/azure/cli/core/decorators.py
@@ -16,6 +16,9 @@
from knack.log import get_logger


logger = get_logger(__name__)


# pylint: disable=too-few-public-methods
class Completer:

@@ -81,3 +84,30 @@ def _wrapped_func(*args, **kwargs):
return fallback_return
return _wrapped_func
return _decorator


def retry(retry_times=3, interval=0.5, exceptions=Exception):
"""Use optimistic locking to call a function, so that multiple processes can
access the same resource (such as a file) at the same time.
:param retry_times: Times to retry.
:param interval: Interval between retries.
:param exceptions: Exceptions that can be ignored. Use a tuple if multiple exceptions should be ignored.
"""
def _decorator(func):
@wraps(func)
def _wrapped_func(*args, **kwargs):
for attempt in range(1, retry_times + 1):
try:
return func(*args, **kwargs)
except exceptions: # pylint: disable=broad-except
if attempt < retry_times:
logger.debug("%s failed in No. %d attempt", func, attempt)
import traceback
import time
logger.debug(traceback.format_exc())
time.sleep(interval)
else:
raise # End of retry. Re-raise the exception as-is.
return _wrapped_func
return _decorator
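For example, the decorator introduced here could wrap a flaky binary read so that only dirty-read errors are retried while anything else surfaces immediately (the function and file name below are illustrative):

import pickle

from azure.cli.core.decorators import retry

@retry(retry_times=3, interval=0.5, exceptions=(pickle.UnpicklingError, EOFError))
def load_binary_cache(path):
    """Read a pickled cache, retrying reads that race with a concurrent writer."""
    with open(path, 'rb') as f:
        return pickle.load(f)

try:
    cache = load_binary_cache('some_cache.bin')
except FileNotFoundError:
    # Not listed in `exceptions`, so it propagates from the first attempt.
    cache = {}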
