diff --git a/docs/backends/azure.rst b/docs/backends/azure.rst index 872a6aaa2..2a2089b4d 100644 --- a/docs/backends/azure.rst +++ b/docs/backends/azure.rst @@ -144,7 +144,7 @@ The following settings are available: ``AZURE_CUSTOM_DOMAIN`` - The custom domain to use. This can be set in the Azure Portal. For + The custom domain to use for generating URLs for files. For example, ``www.mydomain.com`` or ``mycdn.azureedge.net``. ``AZURE_CONNECTION_STRING`` diff --git a/storages/backends/azure_storage.py b/storages/backends/azure_storage.py index f20c49989..6f32919b7 100644 --- a/storages/backends/azure_storage.py +++ b/storages/backends/azure_storage.py @@ -123,7 +123,9 @@ class AzureStorage(BaseStorage): def __init__(self, **settings): super().__init__(**settings) self._service_client = None + self._custom_service_client = None self._client = None + self._custom_client = None self._user_delegation_key = None self._user_delegation_key_expiry = datetime.utcnow() @@ -150,11 +152,11 @@ def get_default_settings(self): "api_version": setting('AZURE_API_VERSION', None), } - def _get_service_client(self): + def _get_service_client(self, use_custom_domain): if self.connection_string is not None: return BlobServiceClient.from_connection_string(self.connection_string) - account_domain = self.custom_domain or "{}.blob.{}".format( + account_domain = self.custom_domain if self.custom_domain and use_custom_domain else "{}.blob.{}".format( self.account_name, self.endpoint_suffix, ) @@ -178,9 +180,15 @@ def _get_service_client(self): @property def service_client(self): if self._service_client is None: - self._service_client = self._get_service_client() + self._service_client = self._get_service_client(use_custom_domain=False) return self._service_client + @property + def custom_service_client(self): + if self._custom_service_client is None: + self._custom_service_client = self._get_service_client(use_custom_domain=True) + return self._custom_service_client + @property def client(self): if self._client is None: @@ -189,6 +197,14 @@ def client(self): ) return self._client + @property + def custom_client(self): + if self._custom_client is None: + self._custom_client = self.custom_service_client.get_container_client( + self.azure_container + ) + return self._custom_client + def get_user_delegation_key(self, expiry): # We'll only be able to get a user delegation key if we've authenticated with a # token credential. @@ -203,7 +219,7 @@ def get_user_delegation_key(self, expiry): ): now = datetime.utcnow() key_expiry_time = now + timedelta(days=7) - self._user_delegation_key = self.service_client.get_user_delegation_key( + self._user_delegation_key = self.custom_service_client.get_user_delegation_key( key_start_time=now, key_expiry_time=key_expiry_time ) self._user_delegation_key_expiry = key_expiry_time @@ -244,11 +260,7 @@ def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN): def exists(self, name): blob_client = self.client.get_blob_client(self._get_valid_path(name)) - try: - blob_client.get_blob_properties() - return True - except ResourceNotFoundError: - return False + return blob_client.exists() def delete(self, name): try: @@ -309,7 +321,7 @@ def url(self, name, expire=None, parameters=None): ) credential = sas_token - container_blob_url = self.client.get_blob_client(name).url + container_blob_url = self.custom_client.get_blob_client(name).url return BlobClient.from_blob_url(container_blob_url, credential=credential).url def _get_content_settings_parameters(self, name, content=None): diff --git a/tests/test_azure.py b/tests/test_azure.py index 149cdd721..941635f7c 100644 --- a/tests/test_azure.py +++ b/tests/test_azure.py @@ -3,7 +3,6 @@ from unittest import mock import django -from azure.core.exceptions import ResourceNotFoundError from azure.storage.blob import BlobProperties from django.core.exceptions import SuspiciousOperation from django.core.files.base import ContentFile @@ -20,6 +19,7 @@ class AzureStorageTest(TestCase): def setUp(self, *args): self.storage = azure_storage.AzureStorage() self.storage._client = mock.MagicMock() + self.storage._custom_client = mock.MagicMock() self.storage.overwrite_files = True self.account_name = 'test' self.account_key = 'key' @@ -88,23 +88,26 @@ def test_get_valid_path_idempotency(self): def test_get_available_name(self): self.storage.overwrite_files = False client_mock = mock.MagicMock() - client_mock.get_blob_properties.side_effect = [True, ResourceNotFoundError] + client_mock.exists.side_effect = [True, False] + custom_client_mock = mock.MagicMock() self.storage._client.get_blob_client.return_value = client_mock + self.storage._custom_client.get_blob_client.return_value = custom_client_mock name = self.storage.get_available_name('foo.txt') self.assertTrue(name.startswith('foo_')) self.assertTrue(name.endswith('.txt')) self.assertTrue(len(name) > len('foo.txt')) - self.assertEqual(client_mock.get_blob_properties.call_count, 2) + self.assertEqual(client_mock.exists.call_count, 2) + self.assertEqual(custom_client_mock.exists.call_count, 0) def test_get_available_name_first(self): self.storage.overwrite_files = False client_mock = mock.MagicMock() - client_mock.get_blob_properties.side_effect = [ResourceNotFoundError] + client_mock.exists.return_value = False self.storage._client.get_blob_client.return_value = client_mock self.assertEqual( self.storage.get_available_name('foo bar baz.txt'), 'foo bar baz.txt') - self.assertEqual(client_mock.get_blob_properties.call_count, 1) + self.assertEqual(client_mock.exists.call_count, 1) def test_get_available_name_max_len(self): self.storage.overwrite_files = False @@ -112,17 +115,17 @@ def test_get_available_name_max_len(self): # storage will raise when file name is too long as well, # the form should validate this client_mock = mock.MagicMock() - client_mock.get_blob_properties.side_effect = [False, ResourceNotFoundError] + client_mock.exists.side_effect = [True, False] self.storage._client.get_blob_client.return_value = client_mock self.assertRaises(ValueError, self.storage.get_available_name, 'a' * 1025) name = self.storage.get_available_name('a' * 1000, max_length=100) # max_len == 1024 self.assertEqual(len(name), 100) self.assertTrue('_' in name) - self.assertEqual(client_mock.get_blob_properties.call_count, 2) + self.assertEqual(client_mock.exists.call_count, 2) def test_get_available_invalid(self): self.storage.overwrite_files = False - self.storage._client.get_blob_properties.return_value = False + self.storage._client.exists.return_value = False if django.VERSION[:2] == (3, 0): # Django 2.2.21 added this security fix: # https://docs.djangoproject.com/en/3.2/releases/2.2.21/#cve-2021-31542-potential-directory-traversal-via-uploaded-files @@ -142,25 +145,27 @@ def test_get_available_invalid(self): def test_url(self): blob_mock = mock.MagicMock() blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob' - self.storage._client.get_blob_client.return_value = blob_mock + self.storage._custom_client.get_blob_client.return_value = blob_mock self.assertEqual(self.storage.url('some blob'), blob_mock.url) - self.storage._client.get_blob_client.assert_called_once_with('some blob') + self.storage.custom_client.get_blob_client.assert_called_once_with('some blob') + self.storage._client.get_blob_client.assert_not_called() def test_url_unsafe_chars(self): blob_mock = mock.MagicMock() blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob' - self.storage._client.get_blob_client.return_value = blob_mock + self.storage._custom_client.get_blob_client.return_value = blob_mock self.assertEqual( self.storage.url('foo;?:@=&"<>#%{}|^~[]`bar/~!*()\''), blob_mock.url) - self.storage.client.get_blob_client.assert_called_once_with( + self.storage.custom_client.get_blob_client.assert_called_once_with( 'foo;?:@=&"<>#%{}|^~[]`bar/~!*()\'') + self.storage._client.get_blob_client.assert_not_called() @mock.patch('storages.backends.azure_storage.generate_blob_sas') def test_url_expire(self, generate_blob_sas_mocked): generate_blob_sas_mocked.return_value = 'foo_token' blob_mock = mock.MagicMock() blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob' - self.storage._client.get_blob_client.return_value = blob_mock + self.storage._custom_client.get_blob_client.return_value = blob_mock self.storage.account_name = self.account_name fixed_time = make_aware(datetime.datetime(2016, 11, 6, 4), timezone.utc) @@ -183,16 +188,16 @@ def test_url_expire_user_delegation_key(self, generate_blob_sas_mocked): generate_blob_sas_mocked.return_value = 'foo_token' blob_mock = mock.MagicMock() blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob' - self.storage._client.get_blob_client.return_value = blob_mock + self.storage._custom_client.get_blob_client.return_value = blob_mock self.storage.account_name = self.account_name - service_client = mock.MagicMock() - self.storage._service_client = service_client + custom_service_client = mock.MagicMock() + self.storage._custom_service_client = custom_service_client self.storage.token_credential = 'token_credential' fixed_time = make_aware(datetime.datetime(2016, 11, 6, 4), timezone.utc) with mock.patch('storages.backends.azure_storage.datetime') as d_mocked: d_mocked.utcnow.return_value = fixed_time - service_client.get_user_delegation_key.return_value = 'user delegation key' + custom_service_client.get_user_delegation_key.return_value = 'user delegation key' self.assertEqual( self.storage.url('some blob', 100), 'https://ret_foo.blob.core.windows.net/test/some%20blob') @@ -228,9 +233,17 @@ def test_container_client_params_account_key(self): 'storages.backends.azure_storage.BlobServiceClient', autospec=True) as bsc_mocked: client_mock = mock.MagicMock() + custom_client_mock = mock.MagicMock() bsc_mocked.return_value.get_container_client.return_value = client_mock self.assertEqual(storage.client, client_mock) bsc_mocked.assert_called_once_with( + 'https://foo_name.blob.core.windows.net', + credential={'account_name': 'foo_name', 'account_key': 'foo_key'}) + + bsc_mocked.return_value.get_container_client.return_value = custom_client_mock + self.assertEqual(storage.custom_client, custom_client_mock) + self.assertEqual(bsc_mocked.call_count, 2) + bsc_mocked.assert_called_with( 'https://foo_domain', credential={'account_name': 'foo_name', 'account_key': 'foo_key'}) @@ -244,9 +257,17 @@ def test_container_client_params_sas_token(self): 'storages.backends.azure_storage.BlobServiceClient', autospec=True) as bsc_mocked: client_mock = mock.MagicMock() + custom_client_mock = mock.MagicMock() bsc_mocked.return_value.get_container_client.return_value = client_mock self.assertEqual(storage.client, client_mock) bsc_mocked.assert_called_once_with( + 'http://foo_name.blob.core.windows.net', + credential='foo_token') + + bsc_mocked.return_value.get_container_client.return_value = custom_client_mock + self.assertEqual(storage.custom_client, custom_client_mock) + self.assertEqual(bsc_mocked.call_count, 2) + bsc_mocked.assert_called_with( 'http://foo_domain', credential='foo_token') @@ -298,6 +319,7 @@ def test_storage_save(self): content_type='text/plain', content_encoding=None, cache_control=None) + self.storage._custom_client.upload_blob.assert_not_called() def test_storage_open_write(self): """ @@ -317,18 +339,23 @@ def test_storage_open_write(self): max_concurrency=2, timeout=20, overwrite=True) + self.storage._custom_client.upload_blob.assert_not_called() def test_storage_exists(self): blob_name = "blob" client_mock = mock.MagicMock() + custom_client_mock = mock.MagicMock() self.storage._client.get_blob_client.return_value = client_mock + self.storage._custom_client.get_blob_client.return_value = client_mock self.assertTrue(self.storage.exists(blob_name)) - self.assertEqual(client_mock.get_blob_properties.call_count, 1) + self.assertEqual(client_mock.exists.call_count, 1) + self.assertEqual(custom_client_mock.exists.call_count, 0) def test_delete_blob(self): self.storage.delete("name") self.storage._client.delete_blob.assert_called_once_with( "name", timeout=20) + self.storage._custom_client.delete_blob.assert_not_called() def test_storage_listdir_base(self): file_names = ["some/path/1.txt", "2.txt", "other/path/3.txt", "4.txt"] @@ -343,6 +370,7 @@ def test_storage_listdir_base(self): dirs, files = self.storage.listdir("") self.storage._client.list_blobs.assert_called_with( name_starts_with="", timeout=20) + self.storage._custom_client.list_blobs.assert_not_called() self.assertEqual(len(dirs), 2) for directory in ["some", "other"]: @@ -365,6 +393,7 @@ def test_storage_listdir_subdir(self): obj.name = p result.append(obj) self.storage._client.list_blobs.return_value = iter(result) + self.storage._custom_client.list_blobs.assert_not_called() dirs, files = self.storage.listdir("some/") self.storage._client.list_blobs.assert_called_with( @@ -424,3 +453,37 @@ def test_override_init_argument(self): self.assertEqual(storage.azure_container, 'foo1') storage = azure_storage.AzureStorage(azure_container='foo2') self.assertEqual(storage.azure_container, 'foo2') + + @mock.patch('storages.backends.azure_storage.AzureStorage._get_service_client',) + def test_get_service_client_use_custom_domain(self, gsc_mocked): + storage = azure_storage.AzureStorage() + storage.account_name = self.account_name + + _ = storage.service_client + gsc_mocked.assert_called_once_with(use_custom_domain=False) + + _ = storage.custom_service_client + gsc_mocked.assert_called_with(use_custom_domain=True) + + def test_blobserviceclient_no_custom_domain(self): + storage = azure_storage.AzureStorage() + storage.account_name = 'foo_name' + storage.custom_domain = None + storage.account_key = 'foo_key' + with mock.patch( + 'storages.backends.azure_storage.BlobServiceClient', + autospec=True) as bsc_mocked: + client_mock = mock.MagicMock() + custom_client_mock = mock.MagicMock() + bsc_mocked.return_value.get_container_client.return_value = client_mock + self.assertEqual(storage.client, client_mock) + bsc_mocked.assert_called_once_with( + 'https://foo_name.blob.core.windows.net', + credential={'account_name': 'foo_name', 'account_key': 'foo_key'}) + + bsc_mocked.return_value.get_container_client.return_value = custom_client_mock + self.assertEqual(storage.custom_client, custom_client_mock) + self.assertEqual(bsc_mocked.call_count, 2) + bsc_mocked.assert_called_with( + 'https://foo_name.blob.core.windows.net', + credential={'account_name': 'foo_name', 'account_key': 'foo_key'})