From 11af5c2c00226b963223783109d749531b13f116 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Wed, 20 Apr 2022 12:39:38 +0430 Subject: [PATCH] Switch to using pathlib --- starlette/staticfiles.py | 35 +++++++++++++++++------------------ tests/test_staticfiles.py | 24 +++++++++++++----------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 89c6ea765..2dfd3334a 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -3,6 +3,7 @@ import stat import typing from email.utils import parsedate +from pathlib import Path import anyio @@ -51,7 +52,7 @@ def __init__( self.all_directories = self.get_directories(directory, packages) self.html = html self.config_checked = False - if check_dir and directory is not None and not os.path.isdir(directory): + if check_dir and directory is not None and not Path(directory).is_dir(): raise RuntimeError(f"Directory '{directory}' does not exist") def get_directories( @@ -77,11 +78,9 @@ def get_directories( spec = importlib.util.find_spec(package) assert spec is not None, f"Package {package!r} could not be found." assert spec.origin is not None, f"Package {package!r} could not be found." - package_directory = os.path.normpath( - os.path.join(spec.origin, "..", statics_dir) - ) - assert os.path.isdir( - package_directory + package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve() + assert ( + package_directory.is_dir() ), f"Directory '{statics_dir!r}' in package {package!r} could not be found." directories.append(package_directory) @@ -101,14 +100,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.get_response(path, scope) await response(scope, receive, send) - def get_path(self, scope: Scope) -> str: + def get_path(self, scope: Scope) -> Path: """ Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. """ - return os.path.normpath(os.path.join(*scope["path"].split("/"))) + return Path(*scope["path"].split("/")) - async def get_response(self, path: str, scope: Scope) -> Response: + async def get_response(self, path: Path, scope: Scope) -> Response: """ Returns an HTTP response, given the incoming path, method and request headers. """ @@ -131,7 +130,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: # We're in HTML mode, and have got a directory URL. # Check if we have 'index.html' file to serve. - index_path = os.path.join(path, "index.html") + index_path = path.joinpath("index.html") full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, index_path ) @@ -158,14 +157,14 @@ async def get_response(self, path: str, scope: Scope) -> Response: raise HTTPException(status_code=404) def lookup_path( - self, path: str - ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: + self, path: Path + ) -> typing.Tuple[Path, typing.Optional[os.stat_result]]: for directory in self.all_directories: - original_path = os.path.join(directory, path) - full_path = os.path.realpath(original_path) - directory = os.path.realpath(directory) - is_external = os.path.commonprefix([full_path, directory]) != directory - if is_external and not os.path.islink(original_path): + original_path = Path(directory).joinpath(path) + full_path = original_path.resolve() + directory = Path(directory).resolve() + is_internal = full_path.is_relative_to(directory) + if not is_internal and not original_path.is_symlink(): # Don't allow misbehaving clients to break out of the static files # directory if not following symlinks. continue @@ -173,7 +172,7 @@ def lookup_path( return full_path, os.stat(full_path) except (FileNotFoundError, NotADirectoryError): continue - return "", None + return Path(), None def file_response( self, diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 8cac8eaa6..4d9b47324 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): directory = os.path.join(tmpdir, "foo") os.mkdir(directory) - path = os.path.join(tmpdir, "example.txt") - with open(path, "w") as file: + file_path = os.path.join(tmpdir, "example.txt") + with open(file_path, "w") as file: file.write("outside root dir") app = StaticFiles(directory=directory) @@ -443,19 +443,21 @@ def mock_timeout(*args, **kwargs): assert response.text == "Internal Server Error" -def test_staticfiles_follows_symlinks_to_break_out_of_dir(tmpdir, test_client_factory): - statics_path = os.path.join(tmpdir, "statics") - os.mkdir(statics_path) +def test_staticfiles_follows_symlinks_to_break_out_of_dir( + tmp_path: pathlib.Path, test_client_factory +): + statics_path = tmp_path.joinpath("statics") + statics_path.mkdir() - symlink_path = os.path.join(tmpdir, "symlink") - os.mkdir(symlink_path) + symlink_path = tmp_path.joinpath("symlink") + symlink_path.mkdir() - symlink_file_path = os.path.join(symlink_path, "index.html") - with open(symlink_file_path, "w") as file: + statics_file_path = statics_path.joinpath("index.html") + with open(statics_file_path, "w") as file: file.write("

Hello

") - statics_file_path = os.path.join(statics_path, "index.html") - os.symlink(symlink_file_path, statics_file_path) + symlink_file_path = symlink_path.joinpath("index.html") + symlink_file_path.symlink_to(statics_file_path) app = StaticFiles(directory=statics_path) client = test_client_factory(app)