Skip to content

Commit

Permalink
Switch to using pathlib
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee committed May 3, 2022
1 parent 952d598 commit a032b12
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
35 changes: 17 additions & 18 deletions starlette/staticfiles.py
Expand Up @@ -3,6 +3,7 @@
import stat
import typing
from email.utils import parsedate
from pathlib import Path

import anyio

Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
if check_dir and directory is not None and not os.path.isdir(directory):
if check_dir and directory is not None and not Path(directory).is_dir():
raise RuntimeError(f"Directory '{directory}' does not exist")

def get_directories(
Expand All @@ -77,11 +78,9 @@ def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
)
assert os.path.isdir(
package_directory
package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
assert (
package_directory.is_dir()
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)

Expand All @@ -101,14 +100,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = await self.get_response(path, scope)
await response(scope, receive, send)

def get_path(self, scope: Scope) -> str:
def get_path(self, scope: Scope) -> Path:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/")))
return Path(*scope["path"].split("/"))

async def get_response(self, path: str, scope: Scope) -> Response:
async def get_response(self, path: Path, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
Expand All @@ -131,7 +130,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
index_path = path.joinpath("index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
Expand All @@ -158,22 +157,22 @@ async def get_response(self, path: str, scope: Scope) -> Response:
raise HTTPException(status_code=404)

def lookup_path(
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
self, path: Path
) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
original_path = os.path.join(directory, path)
full_path = os.path.realpath(original_path)
directory = os.path.realpath(directory)
is_external = os.path.commonprefix([full_path, directory]) != directory
if is_external and not os.path.islink(original_path):
original_path = Path(directory).joinpath(path)
full_path = original_path.resolve()
directory = Path(directory).resolve()
is_internal = full_path.is_relative_to(directory)
if not is_internal and not original_path.is_symlink():
# Don't allow misbehaving clients to break out of the static files
# directory if not following symlinks.
continue
try:
return full_path, os.stat(full_path)
except (FileNotFoundError, NotADirectoryError):
continue
return "", None
return Path(), None

def file_response(
self,
Expand Down
24 changes: 13 additions & 11 deletions tests/test_staticfiles.py
Expand Up @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)

path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file_path = os.path.join(tmpdir, "example.txt")
with open(file_path, "w") as file:
file.write("outside root dir")

app = StaticFiles(directory=directory)
Expand Down Expand Up @@ -443,19 +443,21 @@ def mock_timeout(*args, **kwargs):
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks_to_break_out_of_dir(tmpdir, test_client_factory):
statics_path = os.path.join(tmpdir, "statics")
os.mkdir(statics_path)
def test_staticfiles_follows_symlinks_to_break_out_of_dir(
tmp_path: pathlib.Path, test_client_factory
):
statics_path = tmp_path.joinpath("statics")
statics_path.mkdir()

symlink_path = os.path.join(tmpdir, "symlink")
os.mkdir(symlink_path)
symlink_path = tmp_path.joinpath("symlink")
symlink_path.mkdir()

symlink_file_path = os.path.join(symlink_path, "index.html")
with open(symlink_file_path, "w") as file:
statics_file_path = statics_path.joinpath("index.html")
with open(statics_file_path, "w") as file:
file.write("<h1>Hello</h1>")

statics_file_path = os.path.join(statics_path, "index.html")
os.symlink(symlink_file_path, statics_file_path)
symlink_file_path = symlink_path.joinpath("index.html")
symlink_file_path.symlink_to(statics_file_path)

app = StaticFiles(directory=statics_path)
client = test_client_factory(app)
Expand Down

0 comments on commit a032b12

Please sign in to comment.