diff --git a/docs/staticfiles.md b/docs/staticfiles.md index 3591b4f9f..fa14e6c34 100644 --- a/docs/staticfiles.md +++ b/docs/staticfiles.md @@ -3,12 +3,13 @@ Starlette also includes a `StaticFiles` class for serving files in a given direc ### StaticFiles -Signature: `StaticFiles(directory=None, packages=None, check_dir=True)` +Signature: `StaticFiles(directory=None, packages=None, check_dir=True, follow_symlink=False)` * `directory` - A string or [os.Pathlike][pathlike] denoting a directory path. * `packages` - A list of strings or list of tuples of strings of python packages. * `html` - Run in HTML mode. Automatically loads `index.html` for directories if such file exist. * `check_dir` - Ensure that the directory exists upon instantiation. Defaults to `True`. +* `follow_symlink` - A boolean indicating if symbolic links for files and directories should be followed. Defaults to `False`. You can combine this ASGI application with Starlette's routing to provide comprehensive static file serving. diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index d09630f35..4d075b3ed 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -45,12 +45,14 @@ def __init__( ] = None, html: bool = False, check_dir: bool = True, + follow_symlink: bool = False, ) -> None: self.directory = directory self.packages = packages self.all_directories = self.get_directories(directory, packages) self.html = html self.config_checked = False + self.follow_symlink = follow_symlink if check_dir and directory is not None and not os.path.isdir(directory): raise RuntimeError(f"Directory '{directory}' does not exist") @@ -161,7 +163,11 @@ def lookup_path( self, path: str ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: for directory in self.all_directories: - full_path = os.path.realpath(os.path.join(directory, path)) + joined_path = os.path.join(directory, path) + if self.follow_symlink: + full_path = os.path.abspath(joined_path) + else: + full_path = os.path.realpath(joined_path) directory = os.path.realpath(directory) if os.path.commonprefix([full_path, directory]) != directory: # Don't allow misbehaving clients to break out of the static files diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 142c2a00b..eb6c73f7f 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,6 +1,7 @@ import os import pathlib import stat +import tempfile import time import anyio @@ -448,3 +449,70 @@ def mock_timeout(*args, **kwargs): response = client.get("/example.txt") assert response.status_code == 500 assert response.text == "Internal Server Error" + + +def test_staticfiles_follows_symlinks(tmpdir, test_client_factory): + statics_path = os.path.join(tmpdir, "statics") + os.mkdir(statics_path) + + source_path = tempfile.mkdtemp() + source_file_path = os.path.join(source_path, "page.html") + with open(source_file_path, "w") as file: + file.write("

Hello

") + + statics_file_path = os.path.join(statics_path, "index.html") + os.symlink(source_file_path, statics_file_path) + + app = StaticFiles(directory=statics_path, follow_symlink=True) + client = test_client_factory(app) + + response = client.get("/index.html") + assert response.url == "http://testserver/index.html" + assert response.status_code == 200 + assert response.text == "

Hello

" + + +def test_staticfiles_follows_symlink_directories(tmpdir, test_client_factory): + statics_path = os.path.join(tmpdir, "statics") + statics_html_path = os.path.join(statics_path, "html") + os.mkdir(statics_path) + + source_path = tempfile.mkdtemp() + source_file_path = os.path.join(source_path, "page.html") + with open(source_file_path, "w") as file: + file.write("

Hello

") + + os.symlink(source_path, statics_html_path) + + app = StaticFiles(directory=statics_path, follow_symlink=True) + client = test_client_factory(app) + + response = client.get("/html/page.html") + assert response.url == "http://testserver/html/page.html" + assert response.status_code == 200 + assert response.text == "

Hello

" + + +def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir): + statics_path = os.path.join(tmpdir, "statics") + + root_source_path = tempfile.mkdtemp() + source_path = os.path.join(root_source_path, "statics") + os.mkdir(source_path) + + source_file_path = os.path.join(root_source_path, "index.html") + with open(source_file_path, "w") as file: + file.write("

Hello

") + + os.symlink(source_path, statics_path) + + app = StaticFiles(directory=statics_path, follow_symlink=True) + # We can't test this with 'httpx', so we test the app directly here. + path = app.get_path({"path": "/../index.html"}) + scope = {"method": "GET"} + + with pytest.raises(HTTPException) as exc_info: + anyio.run(app.get_response, path, scope) + + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "Not Found"