Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow staticfiles to follow symlinks outside directory #1377

Merged
merged 17 commits into from May 28, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 23 additions & 20 deletions starlette/staticfiles.py
Expand Up @@ -3,6 +3,7 @@
import stat
import typing
from email.utils import parsedate
from pathlib import Path

import anyio

Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
if check_dir and directory is not None and not os.path.isdir(directory):
if check_dir and directory is not None and not Path(directory).is_dir():
raise RuntimeError(f"Directory '{directory}' does not exist")

def get_directories(
Expand All @@ -77,11 +78,9 @@ def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
)
assert os.path.isdir(
package_directory
package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
assert (
package_directory.is_dir()
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)

Expand All @@ -101,14 +100,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = await self.get_response(path, scope)
await response(scope, receive, send)

def get_path(self, scope: Scope) -> str:
def get_path(self, scope: Scope) -> Path:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/")))
return Path(*scope["path"].split("/"))

async def get_response(self, path: str, scope: Scope) -> Response:
async def get_response(self, path: Path, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
Expand All @@ -131,7 +130,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
index_path = path.joinpath("index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
Expand All @@ -158,20 +157,24 @@ async def get_response(self, path: str, scope: Scope) -> Response:
raise HTTPException(status_code=404)

def lookup_path(
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
self, path: Path
) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
full_path = os.path.realpath(os.path.join(directory, path))
directory = os.path.realpath(directory)
if os.path.commonprefix([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
continue
original_path = Path(directory).joinpath(path)
full_path = original_path.resolve()
directory = Path(directory).resolve()
try:
return full_path, os.stat(full_path)
stat_result = os.lstat(original_path)
full_path.relative_to(directory)
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
return full_path, stat_result
except ValueError:
# Don't allow misbehaving clients to break out of the static files
# directory if not following symlinks.
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
if not stat.S_ISLNK(stat_result.st_mode):
continue
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
except (FileNotFoundError, NotADirectoryError):
continue
return "", None
return Path(), None

def file_response(
self,
Expand Down
29 changes: 27 additions & 2 deletions tests/test_staticfiles.py
Expand Up @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)

path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file_path = os.path.join(tmpdir, "example.txt")
euri10 marked this conversation as resolved.
Show resolved Hide resolved
with open(file_path, "w") as file:
file.write("outside root dir")

app = StaticFiles(directory=directory)
Expand Down Expand Up @@ -441,3 +441,28 @@ def mock_timeout(*args, **kwargs):
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks_to_break_out_of_dir(
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
tmp_path: pathlib.Path, test_client_factory
):
statics_path = tmp_path.joinpath("statics")
statics_path.mkdir()

symlink_path = tmp_path.joinpath("symlink")
symlink_path.mkdir()

statics_file_path = statics_path.joinpath("index.html")
with open(statics_file_path, "w") as file:
file.write("<h1>Hello</h1>")

symlink_file_path = symlink_path.joinpath("index.html")
symlink_file_path.symlink_to(statics_file_path)

app = StaticFiles(directory=statics_path)
client = test_client_factory(app)

response = client.get("/index.html")
assert response.url == "http://testserver/index.html"
assert response.status_code == 200
assert response.text == "<h1>Hello</h1>"