Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix Azure name cleanup #752

Merged
Merged 2 commits on Sep 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 10 additions & 22 deletions storages/backends/azure_storage.py
@@ -1,7 +1,6 @@
from __future__ import unicode_literals

import mimetypes
import re
from datetime import datetime, timedelta
from tempfile import SpooledTemporaryFile

Expand All @@ -13,7 +12,7 @@
from django.core.files.storage import Storage
from django.utils import timezone
from django.utils.deconstruct import deconstructible
from django.utils.encoding import force_bytes, force_text
from django.utils.encoding import filepath_to_uri, force_bytes

from storages.utils import (
clean_name, get_available_overwrite_name, safe_join, setting,
Expand Down Expand Up @@ -101,10 +100,7 @@ def _get_valid_path(s):
# * must not end with dot or slash
# * can contain any character
# * must escape URL reserved characters
# We allow a subset of this to avoid
# illegal file names. We must ensure it is idempotent.
s = force_text(s).strip().replace(' ', '_')
s = re.sub(r'(?u)[^-\w./]', '', s)
# (not needed here since the azure client will do that)
s = s.strip('./')
if len(s) > _AZURE_NAME_MAX_LEN:
raise ValueError(
Expand All @@ -120,12 +116,6 @@ def _get_valid_path(s):
return s


def _clean_name_dance(name):
# `get_valid_path` may return `foo/../bar`
name = name.replace('\\', '/')
return clean_name(_get_valid_path(clean_name(name)))


# Max len according to azure's docs
_AZURE_NAME_MAX_LEN = 1024

Expand Down Expand Up @@ -198,29 +188,27 @@ def azure_protocol(self):
else:
return 'http'

def _path(self, name):
name = _clean_name_dance(name)
def _normalize_name(self, name):
try:
return safe_join(self.location, name)
except ValueError:
raise SuspiciousOperation("Attempted access to '%s' denied." % name)

def _get_valid_path(self, name):
# Must be idempotent
return _get_valid_path(self._path(name))
return _get_valid_path(
self._normalize_name(
clean_name(name)))

def _open(self, name, mode="rb"):
return AzureStorageFile(name, mode, self)

def get_valid_name(self, name):
return _clean_name_dance(name)

def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN):
"""
Returns a filename that's free on the target storage system, and
available for new content to be written to.
"""
name = self.get_valid_name(name)
name = clean_name(name)
if self.overwrite_files:
return get_available_overwrite_name(name, max_length)
return super(AzureStorage, self).get_available_name(name, max_length)
Expand Down Expand Up @@ -248,7 +236,7 @@ def size(self, name):
return properties.content_length

def _save(self, name, content):
name_only = self.get_valid_name(name)
cleaned_name = clean_name(name)
name = self._get_valid_path(name)
guessed_type, content_encoding = mimetypes.guess_type(name)
content_type = (
Expand All @@ -270,7 +258,7 @@ def _save(self, name, content):
content_encoding=content_encoding),
max_connections=self.upload_max_conn,
timeout=self.timeout)
return name_only
return cleaned_name

def _expire_at(self, expire):
# azure expects time in UTC
Expand All @@ -292,7 +280,7 @@ def url(self, name, expire=None):
make_blob_url_kwargs['protocol'] = self.azure_protocol
return self.custom_service.make_blob_url(
container_name=self.azure_container,
blob_name=name,
blob_name=filepath_to_uri(name),
**make_blob_url_kwargs)

def get_modified_time(self, name):
Expand Down
57 changes: 50 additions & 7 deletions tests/integration/test_azure.py
Expand Up @@ -29,30 +29,30 @@ def setUp(self, *args):
self.storage.azure_container, public_access=False, fail_on_exist=False)

def test_save(self):
expected_name = "some_blob_Ϊ.txt"
expected_name = "some blob Ϊ.txt"
self.assertFalse(self.storage.exists(expected_name))
stream = io.BytesIO(b'Im a stream')
name = self.storage.save('some blob Ϊ.txt', stream)
name = self.storage.save(expected_name, stream)
self.assertEqual(name, expected_name)
self.assertTrue(self.storage.exists(expected_name))

def test_delete(self):
self.storage.location = 'path'
expected_name = "some_blob_Ϊ.txt"
expected_name = "some blob Ϊ.txt"
self.assertFalse(self.storage.exists(expected_name))
stream = io.BytesIO(b'Im a stream')
name = self.storage.save('some blob Ϊ.txt', stream)
name = self.storage.save(expected_name, stream)
self.assertEqual(name, expected_name)
self.assertTrue(self.storage.exists(expected_name))
self.storage.delete(expected_name)
self.assertFalse(self.storage.exists(expected_name))

def test_size(self):
self.storage.location = 'path'
expected_name = "some_path/some_blob_Ϊ.txt"
expected_name = "some path/some blob Ϊ.txt"
self.assertFalse(self.storage.exists(expected_name))
stream = io.BytesIO(b'Im a stream')
name = self.storage.save('some path/some blob Ϊ.txt', stream)
name = self.storage.save(expected_name, stream)
self.assertEqual(name, expected_name)
self.assertTrue(self.storage.exists(expected_name))
self.assertEqual(self.storage.size(expected_name), len(b'Im a stream'))
Expand All @@ -64,6 +64,15 @@ def test_url(self):
# has some query-string
self.assertTrue("/test/my_file.txt?" in self.storage.url("my_file.txt"))

def test_url_unsafe_chars(self):
name = "my?file <foo>.txt"
expected = "/test/my%3Ffile%20%3Cfoo%3E.txt"
self.assertTrue(
self.storage.url(name).endswith(expected))
# has some query-string
self.storage.expiration_secs = 360
self.assertTrue("{}?".format(expected) in self.storage.url(name))

def test_url_custom_endpoint(self):
storage = azure_storage.AzureStorage()
storage.is_emulated = True
Expand Down Expand Up @@ -107,7 +116,7 @@ def test_open_read(self):
stream = io.BytesIO()
self.storage.service.get_blob_to_stream(
container_name=self.storage.azure_container,
blob_name='root/path/some_file.txt',
blob_name='root/path/some file.txt',
stream=stream,
max_connections=1,
timeout=10)
Expand Down Expand Up @@ -184,6 +193,11 @@ class AzureStorageExpiry(azure_storage.AzureStorage):
expiration_secs = 360


class AzureStorageSpecialChars(azure_storage.AzureStorage):
def get_valid_name(self, name):
return name


class FooFileForm(forms.Form):

foo_file = forms.FileField()
Expand Down Expand Up @@ -265,3 +279,32 @@ def test_model_form(self):
self.assertEqual(fh.read(), b'foo content')
finally:
fh.close()

def test_name_clean_issue_609(self):
"""
Should strip special characters when using the default storage
"""
simple_file = SimpleFileModel()
simple_file.foo_file = SimpleUploadedFile(
name='foo%?:;~bar.txt',
content=b'foo content')
simple_file.save()
self.assertEqual(simple_file.foo_file.name, 'foo_uploads/foobar.txt')
self.assertTrue('foobar.txt' in simple_file.foo_file.url)

@override_settings(
DEFAULT_FILE_STORAGE='tests.integration.test_azure.AzureStorageSpecialChars')
def test_name_clean_issue_609_with_special_chars(self):
"""
Should not strip special chars
"""
name = 'foo%?:;~bar.txt'
simple_file = SimpleFileModel()
simple_file.foo_file = SimpleUploadedFile(
name=name,
content=b'foo content')
simple_file.save()
self.assertEqual(
simple_file.foo_file.name, 'foo_uploads/{}'.format(name))
self.assertTrue(
'foo_uploads/foo%25%3F%3A%3B~bar.txt' in simple_file.foo_file.url)
44 changes: 27 additions & 17 deletions tests/test_azure.py
Expand Up @@ -45,13 +45,13 @@ def test_get_valid_path(self):
self.storage._get_valid_path("path\\to\\somewhere"),
"path/to/somewhere")
self.assertEqual(
self.storage._get_valid_path("some/$/path"), "some/path")
self.storage._get_valid_path("some/$/path"), "some/$/path")
self.assertEqual(
self.storage._get_valid_path("/$/path"), "path")
self.storage._get_valid_path("/$/path"), "$/path")
self.assertEqual(
self.storage._get_valid_path("path/$/"), "path")
self.storage._get_valid_path("path/$/"), "path/$")
self.assertEqual(
self.storage._get_valid_path("path/$/$/$/path"), "path/path")
self.storage._get_valid_path("path/$/$/$/path"), "path/$/$/$/path")
self.assertEqual(
self.storage._get_valid_path("some///path"), "some/path")
self.assertEqual(
Expand All @@ -67,24 +67,23 @@ def test_get_valid_path(self):
self.assertRaises(ValueError, self.storage._get_valid_path, "/../")
self.assertRaises(ValueError, self.storage._get_valid_path, "..")
self.assertRaises(ValueError, self.storage._get_valid_path, "///")
self.assertRaises(ValueError, self.storage._get_valid_path, "!!!")
self.assertRaises(ValueError, self.storage._get_valid_path, "a" * 1025)
self.assertRaises(ValueError, self.storage._get_valid_path, "a/a" * 257)

def test_get_valid_path_idempotency(self):
self.assertEqual(
self.storage._get_valid_path("//$//a//$//"), "a")
self.storage._get_valid_path("//$//a//$//"), "$/a/$")
self.assertEqual(
self.storage._get_valid_path(
self.storage._get_valid_path("//$//a//$//")),
self.storage._get_valid_path("//$//a//$//"))
some_path = "some path/some long name & then some.txt"
self.assertEqual(
self.storage._get_valid_path("some path/some long name & then some.txt"),
"some_path/some_long_name__then_some.txt")
self.storage._get_valid_path(some_path), some_path)
self.assertEqual(
self.storage._get_valid_path(
self.storage._get_valid_path("some path/some long name & then some.txt")),
self.storage._get_valid_path("some path/some long name & then some.txt"))
self.storage._get_valid_path(some_path)),
self.storage._get_valid_path(some_path))

def test_get_available_name(self):
self.storage.overwrite_files = False
Expand All @@ -100,7 +99,7 @@ def test_get_available_name_first(self):
self.storage._service.exists.return_value = False
self.assertEqual(
self.storage.get_available_name('foo bar baz.txt'),
'foo_bar_baz.txt')
'foo bar baz.txt')
self.assertEqual(self.storage._service.exists.call_count, 1)

def test_get_available_name_max_len(self):
Expand All @@ -119,14 +118,25 @@ def test_get_available_invalid(self):
self.storage.overwrite_files = False
self.storage._service.exists.return_value = False
self.assertRaises(ValueError, self.storage.get_available_name, "")
self.assertRaises(ValueError, self.storage.get_available_name, "$$")
self.assertRaises(ValueError, self.storage.get_available_name, "/")
self.assertRaises(ValueError, self.storage.get_available_name, ".")
self.assertRaises(ValueError, self.storage.get_available_name, "///")
self.assertRaises(ValueError, self.storage.get_available_name, "...")

def test_url(self):
self.storage._custom_service.make_blob_url.return_value = 'ret_foo'
self.assertEqual(self.storage.url('some blob'), 'ret_foo')
self.storage._custom_service.make_blob_url.assert_called_once_with(
container_name=self.container_name,
blob_name='some_blob',
blob_name='some%20blob',
protocol='https')

def test_url_unsafe_chars(self):
self.storage.custom_service.make_blob_url.return_value = 'ret_foo'
self.assertEqual(self.storage.url('foo;?:@=&"<>#%{}|^~[]`bar/~!*()\''), 'ret_foo')
self.storage.custom_service.make_blob_url.assert_called_once_with(
container_name=self.container_name,
blob_name='foo%3B%3F%3A%40%3D%26%22%3C%3E%23%25%7B%7D%7C%5E~%5B%5D%60bar/~!*()\'',
protocol='https')

def test_url_expire(self):
Expand All @@ -139,12 +149,12 @@ def test_url_expire(self):
self.assertEqual(self.storage.url('some blob', 100), 'ret_foo')
self.storage._custom_service.generate_blob_shared_access_signature.assert_called_once_with(
self.container_name,
'some_blob',
'some blob',
permission=BlobPermissions.READ,
expiry=fixed_time + timedelta(seconds=100))
self.storage._custom_service.make_blob_url.assert_called_once_with(
container_name=self.container_name,
blob_name='some_blob',
blob_name='some%20blob',
sas_token='foo_token',
protocol='https')

Expand Down Expand Up @@ -284,10 +294,10 @@ def test_storage_save(self):
content = ContentFile('new content')
with mock.patch('storages.backends.azure_storage.ContentSettings') as c_mocked:
c_mocked.return_value = 'content_settings_foo'
self.assertEqual(self.storage.save(name, content), 'test_storage_save.txt')
self.assertEqual(self.storage.save(name, content), name)
self.storage._service.create_blob_from_stream.assert_called_once_with(
container_name=self.container_name,
blob_name='test_storage_save.txt',
blob_name=name,
stream=content.file,
content_settings='content_settings_foo',
max_connections=2,
Expand Down