Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[azure] Use AZURE_CUSTOM_DOMAIN only for retrieving blob URLs, and use storage URL for other operations #1176

Merged
merged 8 commits into from Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/backends/azure.rst
Expand Up @@ -144,7 +144,7 @@ The following settings are available:

``AZURE_CUSTOM_DOMAIN``

The custom domain to use. This can be set in the Azure Portal. For
The custom domain to use for generating URLs for files. For
example, ``www.mydomain.com`` or ``mycdn.azureedge.net``.

``AZURE_CONNECTION_STRING``
Expand Down
32 changes: 22 additions & 10 deletions storages/backends/azure_storage.py
Expand Up @@ -123,7 +123,9 @@ class AzureStorage(BaseStorage):
def __init__(self, **settings):
super().__init__(**settings)
self._service_client = None
self._custom_service_client = None
self._client = None
self._custom_client = None
self._user_delegation_key = None
self._user_delegation_key_expiry = datetime.utcnow()

Expand All @@ -150,11 +152,11 @@ def get_default_settings(self):
"api_version": setting('AZURE_API_VERSION', None),
}

def _get_service_client(self):
def _get_service_client(self, use_custom_domain):
if self.connection_string is not None:
return BlobServiceClient.from_connection_string(self.connection_string)

account_domain = self.custom_domain or "{}.blob.{}".format(
account_domain = self.custom_domain if self.custom_domain and use_custom_domain else "{}.blob.{}".format(
self.account_name,
self.endpoint_suffix,
)
Expand All @@ -178,9 +180,15 @@ def _get_service_client(self):
@property
def service_client(self):
if self._service_client is None:
self._service_client = self._get_service_client()
self._service_client = self._get_service_client(use_custom_domain=False)
return self._service_client

@property
def custom_service_client(self):
if self._custom_service_client is None:
self._custom_service_client = self._get_service_client(use_custom_domain=True)
return self._custom_service_client

@property
def client(self):
if self._client is None:
Expand All @@ -189,6 +197,14 @@ def client(self):
)
return self._client

@property
def custom_client(self):
if self._custom_client is None:
self._custom_client = self.custom_service_client.get_container_client(
self.azure_container
)
return self._custom_client

def get_user_delegation_key(self, expiry):
# We'll only be able to get a user delegation key if we've authenticated with a
# token credential.
Expand All @@ -203,7 +219,7 @@ def get_user_delegation_key(self, expiry):
):
now = datetime.utcnow()
key_expiry_time = now + timedelta(days=7)
self._user_delegation_key = self.service_client.get_user_delegation_key(
self._user_delegation_key = self.custom_service_client.get_user_delegation_key(
key_start_time=now, key_expiry_time=key_expiry_time
)
self._user_delegation_key_expiry = key_expiry_time
Expand Down Expand Up @@ -244,11 +260,7 @@ def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN):

def exists(self, name):
blob_client = self.client.get_blob_client(self._get_valid_path(name))
try:
blob_client.get_blob_properties()
return True
except ResourceNotFoundError:
return False
return blob_client.exists()

def delete(self, name):
try:
Expand Down Expand Up @@ -309,7 +321,7 @@ def url(self, name, expire=None, parameters=None):
)
credential = sas_token

container_blob_url = self.client.get_blob_client(name).url
container_blob_url = self.custom_client.get_blob_client(name).url
return BlobClient.from_blob_url(container_blob_url, credential=credential).url

def _get_content_settings_parameters(self, name, content=None):
Expand Down
99 changes: 81 additions & 18 deletions tests/test_azure.py
Expand Up @@ -3,7 +3,6 @@
from unittest import mock

import django
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob import BlobProperties
from django.core.exceptions import SuspiciousOperation
from django.core.files.base import ContentFile
Expand All @@ -20,6 +19,7 @@ class AzureStorageTest(TestCase):
def setUp(self, *args):
self.storage = azure_storage.AzureStorage()
self.storage._client = mock.MagicMock()
self.storage._custom_client = mock.MagicMock()
self.storage.overwrite_files = True
self.account_name = 'test'
self.account_key = 'key'
Expand Down Expand Up @@ -88,41 +88,44 @@ def test_get_valid_path_idempotency(self):
def test_get_available_name(self):
self.storage.overwrite_files = False
client_mock = mock.MagicMock()
client_mock.get_blob_properties.side_effect = [True, ResourceNotFoundError]
client_mock.exists.side_effect = [True, False]
custom_client_mock = mock.MagicMock()
self.storage._client.get_blob_client.return_value = client_mock
self.storage._custom_client.get_blob_client.return_value = custom_client_mock
name = self.storage.get_available_name('foo.txt')
self.assertTrue(name.startswith('foo_'))
self.assertTrue(name.endswith('.txt'))
self.assertTrue(len(name) > len('foo.txt'))
self.assertEqual(client_mock.get_blob_properties.call_count, 2)
self.assertEqual(client_mock.exists.call_count, 2)
self.assertEqual(custom_client_mock.exists.call_count, 0)

def test_get_available_name_first(self):
self.storage.overwrite_files = False
client_mock = mock.MagicMock()
client_mock.get_blob_properties.side_effect = [ResourceNotFoundError]
client_mock.exists.return_value = False
self.storage._client.get_blob_client.return_value = client_mock
self.assertEqual(
self.storage.get_available_name('foo bar baz.txt'),
'foo bar baz.txt')
self.assertEqual(client_mock.get_blob_properties.call_count, 1)
self.assertEqual(client_mock.exists.call_count, 1)

def test_get_available_name_max_len(self):
self.storage.overwrite_files = False
# if you wonder why this is, file-system
# storage will raise when file name is too long as well,
# the form should validate this
client_mock = mock.MagicMock()
client_mock.get_blob_properties.side_effect = [False, ResourceNotFoundError]
client_mock.exists.side_effect = [True, False]
self.storage._client.get_blob_client.return_value = client_mock
self.assertRaises(ValueError, self.storage.get_available_name, 'a' * 1025)
name = self.storage.get_available_name('a' * 1000, max_length=100) # max_len == 1024
self.assertEqual(len(name), 100)
self.assertTrue('_' in name)
self.assertEqual(client_mock.get_blob_properties.call_count, 2)
self.assertEqual(client_mock.exists.call_count, 2)

def test_get_available_invalid(self):
self.storage.overwrite_files = False
self.storage._client.get_blob_properties.return_value = False
self.storage._client.exists.return_value = False
if django.VERSION[:2] == (3, 0):
# Django 2.2.21 added this security fix:
# https://docs.djangoproject.com/en/3.2/releases/2.2.21/#cve-2021-31542-potential-directory-traversal-via-uploaded-files
Expand All @@ -142,25 +145,27 @@ def test_get_available_invalid(self):
def test_url(self):
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.assertEqual(self.storage.url('some blob'), blob_mock.url)
self.storage._client.get_blob_client.assert_called_once_with('some blob')
self.storage.custom_client.get_blob_client.assert_called_once_with('some blob')
self.storage._client.get_blob_client.assert_not_called()

def test_url_unsafe_chars(self):
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.assertEqual(
self.storage.url('foo;?:@=&"<>#%{}|^~[]`bar/~!*()\''), blob_mock.url)
self.storage.client.get_blob_client.assert_called_once_with(
self.storage.custom_client.get_blob_client.assert_called_once_with(
'foo;?:@=&"<>#%{}|^~[]`bar/~!*()\'')
self.storage._client.get_blob_client.assert_not_called()

@mock.patch('storages.backends.azure_storage.generate_blob_sas')
def test_url_expire(self, generate_blob_sas_mocked):
generate_blob_sas_mocked.return_value = 'foo_token'
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage.account_name = self.account_name

fixed_time = make_aware(datetime.datetime(2016, 11, 6, 4), timezone.utc)
Expand All @@ -183,16 +188,16 @@ def test_url_expire_user_delegation_key(self, generate_blob_sas_mocked):
generate_blob_sas_mocked.return_value = 'foo_token'
blob_mock = mock.MagicMock()
blob_mock.url = 'https://ret_foo.blob.core.windows.net/test/some%20blob'
self.storage._client.get_blob_client.return_value = blob_mock
self.storage._custom_client.get_blob_client.return_value = blob_mock
self.storage.account_name = self.account_name
service_client = mock.MagicMock()
self.storage._service_client = service_client
custom_service_client = mock.MagicMock()
self.storage._custom_service_client = custom_service_client
self.storage.token_credential = 'token_credential'

fixed_time = make_aware(datetime.datetime(2016, 11, 6, 4), timezone.utc)
with mock.patch('storages.backends.azure_storage.datetime') as d_mocked:
d_mocked.utcnow.return_value = fixed_time
service_client.get_user_delegation_key.return_value = 'user delegation key'
custom_service_client.get_user_delegation_key.return_value = 'user delegation key'
self.assertEqual(
self.storage.url('some blob', 100),
'https://ret_foo.blob.core.windows.net/test/some%20blob')
Expand Down Expand Up @@ -228,9 +233,17 @@ def test_container_client_params_account_key(self):
'storages.backends.azure_storage.BlobServiceClient',
autospec=True) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
'https://foo_name.blob.core.windows.net',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

bsc_mocked.return_value.get_container_client.return_value = custom_client_mock
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
'https://foo_domain',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

Expand All @@ -244,9 +257,17 @@ def test_container_client_params_sas_token(self):
'storages.backends.azure_storage.BlobServiceClient',
autospec=True) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
'http://foo_name.blob.core.windows.net',
credential='foo_token')

bsc_mocked.return_value.get_container_client.return_value = custom_client_mock
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
'http://foo_domain',
credential='foo_token')

Expand Down Expand Up @@ -298,6 +319,7 @@ def test_storage_save(self):
content_type='text/plain',
content_encoding=None,
cache_control=None)
self.storage._custom_client.upload_blob.assert_not_called()

def test_storage_open_write(self):
"""
Expand All @@ -317,18 +339,23 @@ def test_storage_open_write(self):
max_concurrency=2,
timeout=20,
overwrite=True)
self.storage._custom_client.upload_blob.assert_not_called()

def test_storage_exists(self):
blob_name = "blob"
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
self.storage._client.get_blob_client.return_value = client_mock
self.storage._custom_client.get_blob_client.return_value = client_mock
self.assertTrue(self.storage.exists(blob_name))
self.assertEqual(client_mock.get_blob_properties.call_count, 1)
self.assertEqual(client_mock.exists.call_count, 1)
self.assertEqual(custom_client_mock.exists.call_count, 0)

def test_delete_blob(self):
self.storage.delete("name")
self.storage._client.delete_blob.assert_called_once_with(
"name", timeout=20)
self.storage._custom_client.delete_blob.assert_not_called()

def test_storage_listdir_base(self):
file_names = ["some/path/1.txt", "2.txt", "other/path/3.txt", "4.txt"]
Expand All @@ -343,6 +370,7 @@ def test_storage_listdir_base(self):
dirs, files = self.storage.listdir("")
self.storage._client.list_blobs.assert_called_with(
name_starts_with="", timeout=20)
self.storage._custom_client.list_blobs.assert_not_called()

self.assertEqual(len(dirs), 2)
for directory in ["some", "other"]:
Expand All @@ -365,6 +393,7 @@ def test_storage_listdir_subdir(self):
obj.name = p
result.append(obj)
self.storage._client.list_blobs.return_value = iter(result)
self.storage._custom_client.list_blobs.assert_not_called()

dirs, files = self.storage.listdir("some/")
self.storage._client.list_blobs.assert_called_with(
Expand Down Expand Up @@ -424,3 +453,37 @@ def test_override_init_argument(self):
self.assertEqual(storage.azure_container, 'foo1')
storage = azure_storage.AzureStorage(azure_container='foo2')
self.assertEqual(storage.azure_container, 'foo2')

@mock.patch('storages.backends.azure_storage.AzureStorage._get_service_client',)
def test_get_service_client_use_custom_domain(self, gsc_mocked):
storage = azure_storage.AzureStorage()
storage.account_name = self.account_name

_ = storage.service_client
gsc_mocked.assert_called_once_with(use_custom_domain=False)

_ = storage.custom_service_client
gsc_mocked.assert_called_with(use_custom_domain=True)

def test_blobserviceclient_no_custom_domain(self):
storage = azure_storage.AzureStorage()
storage.account_name = 'foo_name'
storage.custom_domain = None
storage.account_key = 'foo_key'
with mock.patch(
'storages.backends.azure_storage.BlobServiceClient',
autospec=True) as bsc_mocked:
client_mock = mock.MagicMock()
custom_client_mock = mock.MagicMock()
bsc_mocked.return_value.get_container_client.return_value = client_mock
self.assertEqual(storage.client, client_mock)
bsc_mocked.assert_called_once_with(
'https://foo_name.blob.core.windows.net',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})

bsc_mocked.return_value.get_container_client.return_value = custom_client_mock
self.assertEqual(storage.custom_client, custom_client_mock)
self.assertEqual(bsc_mocked.call_count, 2)
bsc_mocked.assert_called_with(
'https://foo_name.blob.core.windows.net',
credential={'account_name': 'foo_name', 'account_key': 'foo_key'})