From d3dccdc477652b6de5a7b6b14a2bf3fa2f94be2c Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Sat, 28 May 2022 16:17:54 +0200 Subject: [PATCH] Allow staticfiles to follow symlinks outside directory (#1377) --- starlette/staticfiles.py | 44 +++++++++++++++++++++------------------ tests/test_staticfiles.py | 29 ++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index d09630f35..da10a390c 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -3,6 +3,7 @@ import stat import typing from email.utils import parsedate +from pathlib import Path import anyio @@ -51,7 +52,7 @@ def __init__( self.all_directories = self.get_directories(directory, packages) self.html = html self.config_checked = False - if check_dir and directory is not None and not os.path.isdir(directory): + if check_dir and directory is not None and not Path(directory).is_dir(): raise RuntimeError(f"Directory '{directory}' does not exist") def get_directories( @@ -77,11 +78,9 @@ def get_directories( spec = importlib.util.find_spec(package) assert spec is not None, f"Package {package!r} could not be found." assert spec.origin is not None, f"Package {package!r} could not be found." - package_directory = os.path.normpath( - os.path.join(spec.origin, "..", statics_dir) - ) - assert os.path.isdir( - package_directory + package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve() + assert ( + package_directory.is_dir() ), f"Directory '{statics_dir!r}' in package {package!r} could not be found." directories.append(package_directory) @@ -101,14 +100,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.get_response(path, scope) await response(scope, receive, send) - def get_path(self, scope: Scope) -> str: + def get_path(self, scope: Scope) -> Path: """ Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. """ - return os.path.normpath(os.path.join(*scope["path"].split("/"))) + return Path(*scope["path"].split("/")) - async def get_response(self, path: str, scope: Scope) -> Response: + async def get_response(self, path: Path, scope: Scope) -> Response: """ Returns an HTTP response, given the incoming path, method and request headers. """ @@ -131,7 +130,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: # We're in HTML mode, and have got a directory URL. # Check if we have 'index.html' file to serve. - index_path = os.path.join(path, "index.html") + index_path = path.joinpath("index.html") full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, index_path ) @@ -158,20 +157,25 @@ async def get_response(self, path: str, scope: Scope) -> Response: raise HTTPException(status_code=404) def lookup_path( - self, path: str - ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: + self, path: Path + ) -> typing.Tuple[Path, typing.Optional[os.stat_result]]: for directory in self.all_directories: - full_path = os.path.realpath(os.path.join(directory, path)) - directory = os.path.realpath(directory) - if os.path.commonprefix([full_path, directory]) != directory: - # Don't allow misbehaving clients to break out of the static files - # directory. - continue + original_path = Path(directory).joinpath(path) + full_path = original_path.resolve() + directory = Path(directory).resolve() try: - return full_path, os.stat(full_path) + stat_result = os.lstat(original_path) + full_path.relative_to(directory) + return full_path, stat_result + except ValueError: + # Allow clients to break out of the static files directory + # if following symlinks. + if stat.S_ISLNK(stat_result.st_mode): + stat_result = os.lstat(full_path) + return full_path, stat_result except (FileNotFoundError, NotADirectoryError): continue - return "", None + return Path(), None def file_response( self, diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 7d13a0522..53f3ea9cd 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): directory = os.path.join(tmpdir, "foo") os.mkdir(directory) - path = os.path.join(tmpdir, "example.txt") - with open(path, "w") as file: + file_path = os.path.join(tmpdir, "example.txt") + with open(file_path, "w") as file: file.write("outside root dir") app = StaticFiles(directory=directory) @@ -441,3 +441,28 @@ def mock_timeout(*args, **kwargs): response = client.get("/example.txt") assert response.status_code == 500 assert response.text == "Internal Server Error" + + +def test_staticfiles_follows_symlinks_to_break_out_of_dir( + tmp_path: pathlib.Path, test_client_factory +): + statics_path = tmp_path.joinpath("statics") + statics_path.mkdir() + + symlink_path = tmp_path.joinpath("symlink") + symlink_path.mkdir() + + symlink_file_path = symlink_path.joinpath("index.html") + with open(symlink_file_path, "w") as file: + file.write("

Hello

") + + statics_file_path = statics_path.joinpath("index.html") + statics_file_path.symlink_to(symlink_file_path) + + app = StaticFiles(directory=statics_path) + client = test_client_factory(app) + + response = client.get("/index.html") + assert response.url == "http://testserver/index.html" + assert response.status_code == 200 + assert response.text == "

Hello

"