Skip to content

Commit

Permalink
[azure] Use AZURE_CUSTOM_DOMAIN only for retrieving blob URLs, and …
Browse files Browse the repository at this point in the history
…use storage URL for other operations (#1176)
  • Loading branch information
JeffreyCA committed Sep 27, 2022
1 parent f7aa174 commit f0df471
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/backends/azure.rst
Expand Up @@ -144,7 +144,7 @@ The following settings are available:

``AZURE_CUSTOM_DOMAIN``

The custom domain to use. This can be set in the Azure Portal. For
The custom domain to use for generating URLs for files. For
example, ``www.mydomain.com`` or ``mycdn.azureedge.net``.

``AZURE_CONNECTION_STRING``
Expand Down
32 changes: 22 additions & 10 deletions storages/backends/azure_storage.py
Expand Up @@ -123,7 +123,9 @@ class AzureStorage(BaseStorage):
def __init__(self, **settings):
super().__init__(**settings)
self._service_client = None
self._custom_service_client = None
self._client = None
self._custom_client = None
self._user_delegation_key = None
self._user_delegation_key_expiry = datetime.utcnow()

Expand All @@ -150,11 +152,11 @@ def get_default_settings(self):
"api_version": setting('AZURE_API_VERSION', None),
}

def _get_service_client(self):
def _get_service_client(self, use_custom_domain):
if self.connection_string is not None:
return BlobServiceClient.from_connection_string(self.connection_string)

account_domain = self.custom_domain or "{}.blob.{}".format(
account_domain = self.custom_domain if self.custom_domain and use_custom_domain else "{}.blob.{}".format(
self.account_name,
self.endpoint_suffix,
)
Expand All @@ -178,9 +180,15 @@ def _get_service_client(self):
@property
def service_client(self):
if self._service_client is None:
self._service_client = self._get_service_client()
self._service_client = self._get_service_client(use_custom_domain=False)
return self._service_client

@property
def custom_service_client(self):
if self._custom_service_client is None:
self._custom_service_client = self._get_service_client(use_custom_domain=True)
return self._custom_service_client

@property
def client(self):
if self._client is None:
Expand All @@ -189,6 +197,14 @@ def client(self):
)
return self._client

@property
def custom_client(self):
if self._custom_client is None:
self._custom_client = self.custom_service_client.get_container_client(
self.azure_container
)
return self._custom_client

def get_user_delegation_key(self, expiry):
# We'll only be able to get a user delegation key if we've authenticated with a
# token credential.
Expand All @@ -203,7 +219,7 @@ def get_user_delegation_key(self, expiry):
):
now = datetime.utcnow()
key_expiry_time = now + timedelta(days=7)
self._user_delegation_key = self.service_client.get_user_delegation_key(
self._user_delegation_key = self.custom_service_client.get_user_delegation_key(
key_start_time=now, key_expiry_time=key_expiry_time
)
self._user_delegation_key_expiry = key_expiry_time
Expand Down Expand Up @@ -244,11 +260,7 @@ def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN):

def exists(self, name):
blob_client = self.client.get_blob_client(self._get_valid_path(name))
try:
blob_client.get_blob_properties()
return True
except ResourceNotFoundError:
return False
return blob_client.exists()

def delete(self, name):
try:
Expand Down Expand Up @@ -309,7 +321,7 @@ def url(self, name, expire=None, parameters=None):
)
credential = sas_token

container_blob_url = self.client.get_blob_client(name).url
container_blob_url = self.custom_client.get_blob_client(name).url
return BlobClient.from_blob_url(container_blob_url, credential=credential).url

def _get_content_settings_parameters(self, name, content=None):
Expand Down
99 changes: 81 additions & 18 deletions tests/test_azure.py
Expand Up @@ -3,7 +3,6 @@
from unittest import mock

import django
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob import BlobProperties
from django.core.exceptions import SuspiciousOperation
from django.core.files.base import ContentFile
Expand All @@ -20,6 +19,7 @@ class AzureStorageTest(TestCase):
def setUp(self, *args):
self.storage = azure_storage.AzureStorage()
self.storage._client = mock.MagicMock()
self.storage._custom_client = mock.MagicMock()
self.storage.overwrite_files = True
self.account_name = 'test'
self.account_key = 'key'
Expand Down Expand Up @@ -88,41 +88,44 @@ def test_get_valid_path_idempotency(self):
def test_get_available_name(self):
self.storage.overwrite_files = False
client_mock = mock.MagicMock()
client_mock.get_blob_properties.side_effect = [True, ResourceNotFoundError]
client_mock.exists.side_effect = [True, False]
custom_client_mock = mock.MagicMock()
self.storage._client.get_blob_client.return_value = client_mock
self.storage._custom_client.get_blob_client.return_value = custom_client_mock
name = self.storage.get_available_name('foo.txt')
self.assertTrue(name.startswith('foo_'))
self.assertTrue(name.endswith('.txt'))
self.assertTrue(len(name) > len('foo.txt'))
self.assertEqual(client_mock.get_blob_properties.call_count, 2)
self.assertEqual(client_mock.exists.call_count, 2)
self.assertEqual(custom_client_mock.exists.call_count, 0)

def test_get_available_name_first(self):
self.storage.overwrite_files = False
client_mock = mock.MagicMock()
client_mock.get_blob_properties.side_effect = [ResourceNotFoundError]
client_mock.exists.return_value = False
self.storage._client.get_blob_client.return_value = client_mock
self.assertEqual(
self.storage.get_available_name('foo bar baz.txt'),
'foo bar baz.txt')
self.assertEqual(client_mock.get_blob_properties.call_count, 1)
self.assertEqual(client_mock.exists.call_count, 1)

def test_get_available_name_max_len(self):
self.storage.overwrite_files = False
# if you wonder why this is, file-system
# storage will raise when file name is too long as well,
# the form should validate this
client_mock = mock.MagicMock()
client_mock.get_blob_properties.side_effect = [False, ResourceNotFoundError]
client_mock.exists.side_effect = [True, False]
self.storage._client.get_blob_client.return_value = client_mock
self.assertRaises(ValueError, self.storage.get_available_name, 'a' * 1025)
name = self.storage.get_available_name('a' * 1000, max_length=100) # max_len == 1024
self.assertEqual(len(name), 100)
self.assertTrue('_' in name)
self.assertEqual(client_mock.get_blob_properties.call_count, 2)
self.assertEqual(client_mock.exists.call_count, 2)

def test_get_available_invalid(self):
self.storage.overwrite_files = False
self.storage._client.get_blob_properties.return_value = False
self.storage._client.exists.return_value = False
if django.VERSION[:2] == (3, 0):
# Django 2.2.21 added this security fix:
# https://docs.djangoproject.com/en/3.2/releases/2.2.21/#cve-2021-31542-potential-directory-traversal-via-uploaded-files
Expand All @@ -142,25 +145,27 @@ def test_get_available_invalid(self):
def test_url(self):
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.assertEqual(self.storage.url('some blob'), blob_mock.url)
self.storage._client.get_blob_client.assert_called_once_with('some blob')
self.storage.custom_client.get_blob_client.assert_called_once_with('some blob')
self.storage._client.get_blob_client.assert_not_called()

def test_url_unsafe_chars(self):
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.assertEqual(
self.storage.url('foo;?:@=&"<>#%{}|^~[]`bar/~!*()\''), blob_mock.url)
self.storage.client.get_blob_client.assert_called_once_with(
self.storage.custom_client.get_blob_client.assert_called_once_with(
'foo;?:@=&"<>#%{}|^~[]`bar/~!*()\'')
self.storage._client.get_blob_client.assert_not_called()

@mock.patch('storages.backends.azure_storage.generate_blob_sas')
def test_url_expire(self, generate_blob_sas_mocked):
generate_blob_sas_mocked.return_value = 'foo_token'
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage.account_name = self.account_name

fixed_time = make_aware(datetime.datetime(2016, 11, 6, 4), timezone.utc)
Expand All @@ -183,16 +188,16 @@ def test_url_expire_user_delegation_key(self, generate_blob_sas_mocked):
generate_blob_sas_mocked.return_value = 'foo_token'
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage.account_name = self.account_name
service_client = mock.MagicMock()
self.storage._service_client = service_client
custom_service_client = mock.MagicMock()
self.storage._custom_service_client = custom_service_client
self.storage.token_credential = 'token_credential'

fixed_time = make_aware(datetime.datetime(2016, 11, 6, 4), timezone.utc)
with mock.patch('storages.backends.azure_storage.datetime') as d_mocked:
d_mocked.utcnow.return_value = fixed_time
service_client.get_user_delegation_key.return_value = 'user delegation key'
custom_service_client.get_user_delegation_key.return_value = 'user delegation key'
self.assertEqual(
self.storage.url('some blob', 100),
'https://ret_foo.blob.core.windows.net/test/some%20blob')
Expand Down Expand Up @@ -228,9 +233,17 @@ def test_container_client_params_account_key(self):
'storages.backends.azure_storage.BlobServiceClient',
autospec=True) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
'https://foo_name.blob.core.windows.net',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

bsc_mocked.return_value.get_container_client.return_value = custom_client_mock
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
'https://foo_domain',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

Expand All @@ -244,9 +257,17 @@ def test_container_client_params_sas_token(self):
'storages.backends.azure_storage.BlobServiceClient',
autospec=True) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
'http://foo_name.blob.core.windows.net',
credential='foo_token')

bsc_mocked.return_value.get_container_client.return_value = custom_client_mock
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
'http://foo_domain',
credential='foo_token')

Expand Down Expand Up @@ -298,6 +319,7 @@ def test_storage_save(self):
content_type='text/plain',
content_encoding=None,
cache_control=None)
self.storage._custom_client.upload_blob.assert_not_called()

def test_storage_open_write(self):
"""
Expand All @@ -317,18 +339,23 @@ def test_storage_open_write(self):
max_concurrency=2,
timeout=20,
overwrite=True)
self.storage._custom_client.upload_blob.assert_not_called()

def test_storage_exists(self):
blob_name = "blob"
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
self.storage._client.get_blob_client.return_value = client_mock
self.storage._custom_client.get_blob_client.return_value = client_mock
self.assertTrue(self.storage.exists(blob_name))
self.assertEqual(client_mock.get_blob_properties.call_count, 1)
self.assertEqual(client_mock.exists.call_count, 1)
self.assertEqual(custom_client_mock.exists.call_count, 0)

def test_delete_blob(self):
self.storage.delete("name")
self.storage._client.delete_blob.assert_called_once_with(
"name", timeout=20)
self.storage._custom_client.delete_blob.assert_not_called()

def test_storage_listdir_base(self):
file_names = ["some/path/1.txt", "2.txt", "other/path/3.txt", "4.txt"]
Expand All @@ -343,6 +370,7 @@ def test_storage_listdir_base(self):
dirs, files = self.storage.listdir("")
self.storage._client.list_blobs.assert_called_with(
name_starts_with="", timeout=20)
self.storage._custom_client.list_blobs.assert_not_called()

self.assertEqual(len(dirs), 2)
for directory in ["some", "other"]:
Expand All @@ -365,6 +393,7 @@ def test_storage_listdir_subdir(self):
obj.name = p
result.append(obj)
self.storage._client.list_blobs.return_value = iter(result)
self.storage._custom_client.list_blobs.assert_not_called()

dirs, files = self.storage.listdir("some/")
self.storage._client.list_blobs.assert_called_with(
Expand Down Expand Up @@ -424,3 +453,37 @@ def test_override_init_argument(self):
self.assertEqual(storage.azure_container, 'foo1')
storage = azure_storage.AzureStorage(azure_container='foo2')
self.assertEqual(storage.azure_container, 'foo2')

@mock.patch('storages.backends.azure_storage.AzureStorage._get_service_client',)
def test_get_service_client_use_custom_domain(self, gsc_mocked):
storage = azure_storage.AzureStorage()
storage.account_name = self.account_name

_ = storage.service_client
gsc_mocked.assert_called_once_with(use_custom_domain=False)

_ = storage.custom_service_client
gsc_mocked.assert_called_with(use_custom_domain=True)

def test_blobserviceclient_no_custom_domain(self):
storage = azure_storage.AzureStorage()
storage.account_name = 'foo_name'
storage.custom_domain = None
storage.account_key = 'foo_key'
with mock.patch(
'storages.backends.azure_storage.BlobServiceClient',
autospec=True) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
'https://foo_name.blob.core.windows.net',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

bsc_mocked.return_value.get_container_client.return_value = custom_client_mock
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
'https://foo_name.blob.core.windows.net',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

0 comments on commit f0df471

Please sign in to comment.