From fab084a7c275228e4f4ff289e4bccd48659521e5 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 10 Jun 2022 06:38:43 +0200 Subject: [PATCH] Revert "Allow staticfiles to follow symlinks outside directory (#1377)" This reverts commit d3dccdc477652b6de5a7b6b14a2bf3fa2f94be2c. --- starlette/staticfiles.py | 44 ++++++++++++++++++--------------------- tests/test_staticfiles.py | 29 ++------------------------ 2 files changed, 22 insertions(+), 51 deletions(-) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index da10a390c..d09630f35 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -3,7 +3,6 @@ import stat import typing from email.utils import parsedate -from pathlib import Path import anyio @@ -52,7 +51,7 @@ def __init__( self.all_directories = self.get_directories(directory, packages) self.html = html self.config_checked = False - if check_dir and directory is not None and not Path(directory).is_dir(): + if check_dir and directory is not None and not os.path.isdir(directory): raise RuntimeError(f"Directory '{directory}' does not exist") def get_directories( @@ -78,9 +77,11 @@ def get_directories( spec = importlib.util.find_spec(package) assert spec is not None, f"Package {package!r} could not be found." assert spec.origin is not None, f"Package {package!r} could not be found." - package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve() - assert ( - package_directory.is_dir() + package_directory = os.path.normpath( + os.path.join(spec.origin, "..", statics_dir) + ) + assert os.path.isdir( + package_directory ), f"Directory '{statics_dir!r}' in package {package!r} could not be found." directories.append(package_directory) @@ -100,14 +101,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.get_response(path, scope) await response(scope, receive, send) - def get_path(self, scope: Scope) -> Path: + def get_path(self, scope: Scope) -> str: """ Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. """ - return Path(*scope["path"].split("/")) + return os.path.normpath(os.path.join(*scope["path"].split("/"))) - async def get_response(self, path: Path, scope: Scope) -> Response: + async def get_response(self, path: str, scope: Scope) -> Response: """ Returns an HTTP response, given the incoming path, method and request headers. """ @@ -130,7 +131,7 @@ async def get_response(self, path: Path, scope: Scope) -> Response: elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: # We're in HTML mode, and have got a directory URL. # Check if we have 'index.html' file to serve. - index_path = path.joinpath("index.html") + index_path = os.path.join(path, "index.html") full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, index_path ) @@ -157,25 +158,20 @@ async def get_response(self, path: Path, scope: Scope) -> Response: raise HTTPException(status_code=404) def lookup_path( - self, path: Path - ) -> typing.Tuple[Path, typing.Optional[os.stat_result]]: + self, path: str + ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: for directory in self.all_directories: - original_path = Path(directory).joinpath(path) - full_path = original_path.resolve() - directory = Path(directory).resolve() + full_path = os.path.realpath(os.path.join(directory, path)) + directory = os.path.realpath(directory) + if os.path.commonprefix([full_path, directory]) != directory: + # Don't allow misbehaving clients to break out of the static files + # directory. + continue try: - stat_result = os.lstat(original_path) - full_path.relative_to(directory) - return full_path, stat_result - except ValueError: - # Allow clients to break out of the static files directory - # if following symlinks. - if stat.S_ISLNK(stat_result.st_mode): - stat_result = os.lstat(full_path) - return full_path, stat_result + return full_path, os.stat(full_path) except (FileNotFoundError, NotADirectoryError): continue - return Path(), None + return "", None def file_response( self, diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 53f3ea9cd..7d13a0522 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): directory = os.path.join(tmpdir, "foo") os.mkdir(directory) - file_path = os.path.join(tmpdir, "example.txt") - with open(file_path, "w") as file: + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: file.write("outside root dir") app = StaticFiles(directory=directory) @@ -441,28 +441,3 @@ def mock_timeout(*args, **kwargs): response = client.get("/example.txt") assert response.status_code == 500 assert response.text == "Internal Server Error" - - -def test_staticfiles_follows_symlinks_to_break_out_of_dir( - tmp_path: pathlib.Path, test_client_factory -): - statics_path = tmp_path.joinpath("statics") - statics_path.mkdir() - - symlink_path = tmp_path.joinpath("symlink") - symlink_path.mkdir() - - symlink_file_path = symlink_path.joinpath("index.html") - with open(symlink_file_path, "w") as file: - file.write("

Hello

") - - statics_file_path = statics_path.joinpath("index.html") - statics_file_path.symlink_to(symlink_file_path) - - app = StaticFiles(directory=statics_path) - client = test_client_factory(app) - - response = client.get("/index.html") - assert response.url == "http://testserver/index.html" - assert response.status_code == 200 - assert response.text == "

Hello

"