diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index cdc119fba..ae694bab0 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,6 +1,7 @@ import os import pathlib import stat +import tempfile import time import anyio @@ -454,15 +455,13 @@ def test_staticfiles_follows_symlinks(tmpdir, test_client_factory): statics_path = os.path.join(tmpdir, "statics") os.mkdir(statics_path) - symlink_path = os.path.join(tmpdir, "symlink") - os.mkdir(symlink_path) - - symlink_file_path = os.path.join(symlink_path, "index.html") - with open(symlink_file_path, "w") as file: + source_path = tempfile.mkdtemp() + source_file_path = os.path.join(source_path, "page.html") + with open(source_file_path, "w") as file: file.write("

Hello

") statics_file_path = os.path.join(statics_path, "index.html") - os.symlink(symlink_file_path, statics_file_path) + os.symlink(source_file_path, statics_file_path) app = StaticFiles(directory=statics_path) client = test_client_factory(app) @@ -473,19 +472,39 @@ def test_staticfiles_follows_symlinks(tmpdir, test_client_factory): assert response.text == "

Hello

" -def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir): +def test_staticfiles_follows_symlink_directories(tmpdir, test_client_factory): statics_path = os.path.join(tmpdir, "statics") + statics_html_path = os.path.join(statics_path, "html") os.mkdir(statics_path) - symlink_path = os.path.join(tmpdir, "symlink") - os.mkdir(symlink_path) + source_path = tempfile.mkdtemp() + source_file_path = os.path.join(source_path, "page.html") + with open(source_file_path, "w") as file: + file.write("

Hello

") + + os.symlink(source_path, statics_html_path) + + app = StaticFiles(directory=statics_path) + client = test_client_factory(app) + + response = client.get("/html/page.html") + assert response.url == "http://testserver/html/page.html" + assert response.status_code == 200 + assert response.text == "

Hello

" + + +def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir): + statics_path = os.path.join(tmpdir, "statics") + + root_source_path = tempfile.mkdtemp() + source_path = os.path.join(root_source_path, "statics") + os.mkdir(source_path) - symlink_file_path = os.path.join(symlink_path, "index.html") - with open(symlink_file_path, "w") as file: + source_file_path = os.path.join(root_source_path, "index.html") + with open(source_file_path, "w") as file: file.write("

Hello

") - temp_path = os.path.join(tmpdir, "index.html") - os.symlink(symlink_file_path, temp_path) + os.symlink(source_path, statics_path) app = StaticFiles(directory=statics_path) # We can't test this with 'requests', so we test the app directly here.