Skip to content

Commit

Permalink
Allow StaticFiles follow symlinks (#1683)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
aminalaee and Kludex committed Feb 4, 2023
1 parent ea70fd5 commit ea2e794
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/staticfiles.md
Expand Up @@ -3,12 +3,13 @@ Starlette also includes a `StaticFiles` class for serving files in a given direc

### StaticFiles

Signature: `StaticFiles(directory=None, packages=None, check_dir=True)`
Signature: `StaticFiles(directory=None, packages=None, html=False, check_dir=True, follow_symlink=False)`

* `directory` - A string or [os.PathLike][pathlike] denoting a directory path.
* `packages` - A list of strings or list of tuples of strings of python packages.
* `html` - Run in HTML mode. Automatically loads `index.html` for directories if such a file exists.
* `check_dir` - Ensure that the directory exists upon instantiation. Defaults to `True`.
* `follow_symlink` - A boolean indicating if symbolic links for files and directories should be followed. Defaults to `False`.

You can combine this ASGI application with Starlette's routing to provide
comprehensive static file serving.
Expand Down
8 changes: 7 additions & 1 deletion starlette/staticfiles.py
Expand Up @@ -45,12 +45,14 @@ def __init__(
] = None,
html: bool = False,
check_dir: bool = True,
follow_symlink: bool = False,
) -> None:
self.directory = directory
self.packages = packages
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
self.follow_symlink = follow_symlink
if check_dir and directory is not None and not os.path.isdir(directory):
raise RuntimeError(f"Directory '{directory}' does not exist")

Expand Down Expand Up @@ -161,7 +163,11 @@ def lookup_path(
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
full_path = os.path.realpath(os.path.join(directory, path))
joined_path = os.path.join(directory, path)
if self.follow_symlink:
full_path = os.path.abspath(joined_path)
else:
full_path = os.path.realpath(joined_path)
directory = os.path.realpath(directory)
if os.path.commonprefix([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
Expand Down
68 changes: 68 additions & 0 deletions tests/test_staticfiles.py
@@ -1,6 +1,7 @@
import os
import pathlib
import stat
import tempfile
import time

import anyio
Expand Down Expand Up @@ -448,3 +449,70 @@ def mock_timeout(*args, **kwargs):
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks(tmpdir, test_client_factory):
    """Serve a symlinked file when ``follow_symlink=True``.

    ``statics/index.html`` is a symbolic link whose target lives in a
    separate directory outside ``statics``; the request must still succeed
    and return the target's content.
    """
    statics_path = os.path.join(tmpdir, "statics")
    os.mkdir(statics_path)

    # Create the link target under pytest's ``tmpdir`` (but outside
    # ``statics``) so it is cleaned up with the rest of the test's files
    # instead of being leaked into the system temp location.
    source_path = tempfile.mkdtemp(dir=str(tmpdir))
    source_file_path = os.path.join(source_path, "page.html")
    with open(source_file_path, "w") as file:
        file.write("<h1>Hello</h1>")

    # statics/index.html -> <source_path>/page.html
    statics_file_path = os.path.join(statics_path, "index.html")
    os.symlink(source_file_path, statics_file_path)

    app = StaticFiles(directory=statics_path, follow_symlink=True)
    client = test_client_factory(app)

    response = client.get("/index.html")
    assert response.url == "http://testserver/index.html"
    assert response.status_code == 200
    assert response.text == "<h1>Hello</h1>"


def test_staticfiles_follows_symlink_directories(tmpdir, test_client_factory):
    """Serve a file through a symlinked directory when ``follow_symlink=True``.

    ``statics/html`` is a symbolic link to a directory outside ``statics``;
    a file inside that directory must be reachable via ``/html/<name>``.
    """
    statics_path = os.path.join(tmpdir, "statics")
    statics_html_path = os.path.join(statics_path, "html")
    os.mkdir(statics_path)

    # Create the link target under pytest's ``tmpdir`` (but outside
    # ``statics``) so it is cleaned up with the rest of the test's files
    # instead of being leaked into the system temp location.
    source_path = tempfile.mkdtemp(dir=str(tmpdir))
    source_file_path = os.path.join(source_path, "page.html")
    with open(source_file_path, "w") as file:
        file.write("<h1>Hello</h1>")

    # statics/html -> <source_path>
    os.symlink(source_path, statics_html_path)

    app = StaticFiles(directory=statics_path, follow_symlink=True)
    client = test_client_factory(app)

    response = client.get("/html/page.html")
    assert response.url == "http://testserver/html/page.html"
    assert response.status_code == 200
    assert response.text == "<h1>Hello</h1>"


def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir):
    """Reject ``..`` traversal even when the static directory is a symlink.

    The static directory itself is a symbolic link to ``<root>/statics``,
    and ``<root>/index.html`` sits one level above the link target.
    Requesting ``/../index.html`` must produce a 404 rather than serving
    that file.
    """
    statics_path = os.path.join(tmpdir, "statics")

    # Create the link target tree under pytest's ``tmpdir`` so it is cleaned
    # up with the rest of the test's files instead of being leaked into the
    # system temp location.
    root_source_path = tempfile.mkdtemp(dir=str(tmpdir))
    source_path = os.path.join(root_source_path, "statics")
    os.mkdir(source_path)

    # The file lives next to (not inside) the real static directory.
    source_file_path = os.path.join(root_source_path, "index.html")
    with open(source_file_path, "w") as file:
        file.write("<h1>Hello</h1>")

    # <tmpdir>/statics -> <root_source_path>/statics
    os.symlink(source_path, statics_path)

    app = StaticFiles(directory=statics_path, follow_symlink=True)
    # We can't test this with 'httpx', so we test the app directly here.
    path = app.get_path({"path": "/../index.html"})
    scope = {"method": "GET"}

    with pytest.raises(HTTPException) as exc_info:
        anyio.run(app.get_response, path, scope)

    assert exc_info.value.status_code == 404
    assert exc_info.value.detail == "Not Found"

0 comments on commit ea2e794

Please sign in to comment.