diff --git a/storages/backends/azure_storage.py b/storages/backends/azure_storage.py index f476fd87a..ffb53de5e 100644 --- a/storages/backends/azure_storage.py +++ b/storages/backends/azure_storage.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals import mimetypes -import re from datetime import datetime, timedelta from tempfile import SpooledTemporaryFile @@ -13,10 +12,13 @@ from django.core.files.storage import Storage from django.utils import timezone from django.utils.deconstruct import deconstructible -from django.utils.encoding import force_bytes, force_text +from django.utils.encoding import force_bytes, filepath_to_uri from storages.utils import ( - clean_name, get_available_overwrite_name, safe_join, setting, + clean_name, + get_available_overwrite_name, + safe_join, + setting, ) @@ -103,10 +105,7 @@ def _get_valid_path(s): # * must not end with dot or slash # * can contain any character # * must escape URL reserved characters - # We allow a subset of this to avoid - # illegal file names. We must ensure it is idempotent. 
- s = force_text(s).strip().replace(' ', '_') - s = re.sub(r'(?u)[^-\w./]', '', s) + # (not needed here since the azure client will do that) s = s.strip('./') if len(s) > _AZURE_NAME_MAX_LEN: raise ValueError( @@ -122,12 +121,6 @@ def _get_valid_path(s): return s -def _clean_name_dance(name): - # `get_valid_path` may return `foo/../bar` - name = name.replace('\\', '/') - return clean_name(_get_valid_path(clean_name(name))) - - # Max len according to azure's docs _AZURE_NAME_MAX_LEN = 1024 @@ -170,8 +163,7 @@ def azure_protocol(self): else: return 'http' - def _path(self, name): - name = _clean_name_dance(name) + def _normalize_name(self, name): try: return safe_join(self.location, name) except ValueError: @@ -179,20 +171,19 @@ def _path(self, name): def _get_valid_path(self, name): # Must be idempotent - return _get_valid_path(self._path(name)) + return _get_valid_path( + self._normalize_name( + clean_name(name))) def _open(self, name, mode="rb"): return AzureStorageFile(name, mode, self) - def get_valid_name(self, name): - return _clean_name_dance(name) - def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN): """ Returns a filename that's free on the target storage system, and available for new content to be written to. 
""" - name = self.get_valid_name(name) + name = clean_name(name) if self.overwrite_files: return get_available_overwrite_name(name, max_length) return super(AzureStorage, self).get_available_name(name, max_length) @@ -220,7 +211,7 @@ def size(self, name): return properties.content_length def _save(self, name, content): - name_only = self.get_valid_name(name) + cleaned_name = clean_name(name) name = self._get_valid_path(name) guessed_type, content_encoding = mimetypes.guess_type(name) content_type = ( @@ -242,7 +233,7 @@ def _save(self, name, content): content_encoding=content_encoding), max_connections=self.upload_max_conn, timeout=self.timeout) - return name_only + return cleaned_name def _expire_at(self, expire): # azure expects time in UTC @@ -264,7 +255,7 @@ def url(self, name, expire=None): make_blob_url_kwargs['protocol'] = self.azure_protocol return self.service.make_blob_url( container_name=self.azure_container, - blob_name=name, + blob_name=filepath_to_uri(name), **make_blob_url_kwargs) def get_modified_time(self, name): diff --git a/tests/integration/test_azure.py b/tests/integration/test_azure.py index edfc58e94..53865d708 100644 --- a/tests/integration/test_azure.py +++ b/tests/integration/test_azure.py @@ -29,19 +29,19 @@ def setUp(self, *args): self.storage.azure_container, public_access=False, fail_on_exist=False) def test_save(self): - expected_name = "some_blob_Ϊ.txt" + expected_name = "some blob Ϊ.txt" self.assertFalse(self.storage.exists(expected_name)) stream = io.BytesIO(b'Im a stream') - name = self.storage.save('some blob Ϊ.txt', stream) + name = self.storage.save(expected_name, stream) self.assertEqual(name, expected_name) self.assertTrue(self.storage.exists(expected_name)) def test_delete(self): self.storage.location = 'path' - expected_name = "some_blob_Ϊ.txt" + expected_name = "some blob Ϊ.txt" self.assertFalse(self.storage.exists(expected_name)) stream = io.BytesIO(b'Im a stream') - name = self.storage.save('some blob Ϊ.txt', stream) + 
name = self.storage.save(expected_name, stream) self.assertEqual(name, expected_name) self.assertTrue(self.storage.exists(expected_name)) self.storage.delete(expected_name) @@ -49,10 +49,10 @@ def test_delete(self): def test_size(self): self.storage.location = 'path' - expected_name = "some_path/some_blob_Ϊ.txt" + expected_name = "some path/some blob Ϊ.txt" self.assertFalse(self.storage.exists(expected_name)) stream = io.BytesIO(b'Im a stream') - name = self.storage.save('some path/some blob Ϊ.txt', stream) + name = self.storage.save(expected_name, stream) self.assertEqual(name, expected_name) self.assertTrue(self.storage.exists(expected_name)) self.assertEqual(self.storage.size(expected_name), len(b'Im a stream')) @@ -64,6 +64,15 @@ def test_url(self): # has some query-string self.assertTrue("/test/my_file.txt?" in self.storage.url("my_file.txt")) + def test_url_unsafe_chars(self): + name = "my?file <foo>.txt" + expected = "/test/my%3Ffile%20%3Cfoo%3E.txt" + self.assertTrue( + self.storage.url(name).endswith(expected)) + # has some query-string + self.storage.expiration_secs = 360 + self.assertTrue("{}?".format(expected) in self.storage.url(name)) + @override_settings(USE_TZ=True) def test_get_modified_time_tz(self): stream = io.BytesIO(b'Im a stream') @@ -101,7 +110,7 @@ def test_open_read(self): stream = io.BytesIO() self.storage.service.get_blob_to_stream( container_name=self.storage.azure_container, - blob_name='root/path/some_file.txt', + blob_name='root/path/some file.txt', stream=stream, max_connections=1, timeout=10) @@ -178,6 +187,11 @@ class AzureStorageExpiry(azure_storage.AzureStorage): expiration_secs = 360 +class AzureStorageSpecialChars(azure_storage.AzureStorage): + def get_valid_name(self, name): + return name + + class FooFileForm(forms.Form): foo_file = forms.FileField() @@ -259,3 +273,32 @@ def test_model_form(self): self.assertEqual(fh.read(), b'foo content') finally: fh.close() + + def test_name_clean_issue_609(self): + """ + Should strip special 
characters when using the default storage + """ + simple_file = SimpleFileModel() + simple_file.foo_file = SimpleUploadedFile( + name='foo%?:;~bar.txt', + content=b'foo content') + simple_file.save() + self.assertEqual(simple_file.foo_file.name, 'foo_uploads/foobar.txt') + self.assertTrue('foobar.txt' in simple_file.foo_file.url) + + @override_settings( + DEFAULT_FILE_STORAGE='tests.integration.test_azure.AzureStorageSpecialChars') + def test_name_clean_issue_609_with_special_chars(self): + """ + Should not strip special chars + """ + name = 'foo%?:;~bar.txt' + simple_file = SimpleFileModel() + simple_file.foo_file = SimpleUploadedFile( + name=name, + content=b'foo content') + simple_file.save() + self.assertEqual( + simple_file.foo_file.name, 'foo_uploads/{}'.format(name)) + self.assertTrue( + 'foo_uploads/foo%25%3F%3A%3B~bar.txt' in simple_file.foo_file.url) diff --git a/tests/test_azure.py b/tests/test_azure.py index 55e93813f..e83a00b75 100644 --- a/tests/test_azure.py +++ b/tests/test_azure.py @@ -44,13 +44,13 @@ def test_get_valid_path(self): self.storage._get_valid_path("path\\to\\somewhere"), "path/to/somewhere") self.assertEqual( - self.storage._get_valid_path("some/$/path"), "some/path") + self.storage._get_valid_path("some/$/path"), "some/$/path") self.assertEqual( - self.storage._get_valid_path("/$/path"), "path") + self.storage._get_valid_path("/$/path"), "$/path") self.assertEqual( - self.storage._get_valid_path("path/$/"), "path") + self.storage._get_valid_path("path/$/"), "path/$") self.assertEqual( - self.storage._get_valid_path("path/$/$/$/path"), "path/path") + self.storage._get_valid_path("path/$/$/$/path"), "path/$/$/$/path") self.assertEqual( self.storage._get_valid_path("some///path"), "some/path") self.assertEqual( @@ -66,24 +66,23 @@ def test_get_valid_path(self): self.assertRaises(ValueError, self.storage._get_valid_path, "/../") self.assertRaises(ValueError, self.storage._get_valid_path, "..") self.assertRaises(ValueError, 
self.storage._get_valid_path, "///") - self.assertRaises(ValueError, self.storage._get_valid_path, "!!!") self.assertRaises(ValueError, self.storage._get_valid_path, "a" * 1025) self.assertRaises(ValueError, self.storage._get_valid_path, "a/a" * 257) def test_get_valid_path_idempotency(self): self.assertEqual( - self.storage._get_valid_path("//$//a//$//"), "a") + self.storage._get_valid_path("//$//a//$//"), "$/a/$") self.assertEqual( self.storage._get_valid_path( self.storage._get_valid_path("//$//a//$//")), self.storage._get_valid_path("//$//a//$//")) + some_path = "some path/some long name & then some.txt" self.assertEqual( - self.storage._get_valid_path("some path/some long name & then some.txt"), - "some_path/some_long_name__then_some.txt") + self.storage._get_valid_path(some_path), some_path) self.assertEqual( self.storage._get_valid_path( - self.storage._get_valid_path("some path/some long name & then some.txt")), - self.storage._get_valid_path("some path/some long name & then some.txt")) + self.storage._get_valid_path(some_path)), + self.storage._get_valid_path(some_path)) def test_get_available_name(self): self.storage.overwrite_files = False @@ -99,7 +98,7 @@ def test_get_available_name_first(self): self.storage._service.exists.return_value = False self.assertEqual( self.storage.get_available_name('foo bar baz.txt'), - 'foo_bar_baz.txt') + 'foo bar baz.txt') self.assertEqual(self.storage._service.exists.call_count, 1) def test_get_available_name_max_len(self): @@ -118,14 +117,25 @@ def test_get_available_invalid(self): self.storage.overwrite_files = False self.storage._service.exists.return_value = False self.assertRaises(ValueError, self.storage.get_available_name, "") - self.assertRaises(ValueError, self.storage.get_available_name, "$$") + self.assertRaises(ValueError, self.storage.get_available_name, "/") + self.assertRaises(ValueError, self.storage.get_available_name, ".") + self.assertRaises(ValueError, self.storage.get_available_name, "///") + 
self.assertRaises(ValueError, self.storage.get_available_name, "...") def test_url(self): self.storage._service.make_blob_url.return_value = 'ret_foo' self.assertEqual(self.storage.url('some blob'), 'ret_foo') self.storage._service.make_blob_url.assert_called_once_with( container_name=self.container_name, - blob_name='some_blob', + blob_name='some%20blob', + protocol='https') + + def test_url_unsafe_chars(self): + self.storage._service.make_blob_url.return_value = 'ret_foo' + self.assertEqual(self.storage.url('foo;?:@=&"<>#%{}|^~[]`bar/~!*()\''), 'ret_foo') + self.storage._service.make_blob_url.assert_called_once_with( + container_name=self.container_name, + blob_name='foo%3B%3F%3A%40%3D%26%22%3C%3E%23%25%7B%7D%7C%5E~%5B%5D%60bar/~!*()\'', protocol='https') def test_url_expire(self): @@ -138,12 +148,12 @@ def test_url_expire(self): self.assertEqual(self.storage.url('some blob', 100), 'ret_foo') self.storage._service.generate_blob_shared_access_signature.assert_called_once_with( self.container_name, - 'some_blob', + 'some blob', BlobPermissions.READ, expiry=fixed_time + timedelta(seconds=100)) self.storage._service.make_blob_url.assert_called_once_with( container_name=self.container_name, - blob_name='some_blob', + blob_name='some%20blob', sas_token='foo_token', protocol='https') @@ -157,10 +167,10 @@ def test_storage_save(self): content = ContentFile('new content') with mock.patch('storages.backends.azure_storage.ContentSettings') as c_mocked: c_mocked.return_value = 'content_settings_foo' - self.assertEqual(self.storage.save(name, content), 'test_storage_save.txt') + self.assertEqual(self.storage.save(name, content), name) self.storage._service.create_blob_from_stream.assert_called_once_with( container_name=self.container_name, - blob_name='test_storage_save.txt', + blob_name=name, stream=content.file, content_settings='content_settings_foo', max_connections=2,