diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 33ea0b033..39a697260 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -7,12 +7,8 @@ import anyio from starlette.datastructures import URL, Headers -from starlette.responses import ( - FileResponse, - PlainTextResponse, - RedirectResponse, - Response, -) +from starlette.exceptions import HTTPException +from starlette.responses import FileResponse, RedirectResponse, Response from starlette.types import Receive, Scope, Send PathLike = typing.Union[str, "os.PathLike[str]"] @@ -109,9 +105,30 @@ async def get_response(self, path: str, scope: Scope) -> Response: Returns an HTTP response, given the incoming path, method and request headers. """ if scope["method"] not in ("GET", "HEAD"): - return PlainTextResponse("Method Not Allowed", status_code=405) + raise HTTPException(status_code=405) - full_path, stat_result = await self.lookup_path(path) + try: + full_path, stat_result = await anyio.to_thread.run_sync( + self.lookup_path, path + ) + except (FileNotFoundError, NotADirectoryError): + if self.html: + # Check for '404.html' if we're in HTML mode. + full_path, stat_result = await anyio.to_thread.run_sync( + self.lookup_path, "404.html" + ) + if stat_result and stat.S_ISREG(stat_result.st_mode): + return FileResponse( + full_path, + stat_result=stat_result, + method=scope["method"], + status_code=404, + ) + raise HTTPException(status_code=404) + except PermissionError: + raise HTTPException(status_code=401) + except OSError: + raise if stat_result and stat.S_ISREG(stat_result.st_mode): # We have a static file to serve. @@ -121,7 +138,9 @@ async def get_response(self, path: str, scope: Scope) -> Response: # We're in HTML mode, and have got a directory URL. # Check if we have 'index.html' file to serve. index_path = os.path.join(path, "index.html") - full_path, stat_result = await self.lookup_path(index_path) + full_path, stat_result = await anyio.to_thread.run_sync( + self.lookup_path, index_path + ) if stat_result is not None and stat.S_ISREG(stat_result.st_mode): if not scope["path"].endswith("/"): # Directory URLs should redirect to always end in "/". @@ -130,20 +149,9 @@ async def get_response(self, path: str, scope: Scope) -> Response: return RedirectResponse(url=url) return self.file_response(full_path, stat_result, scope) - if self.html: - # Check for '404.html' if we're in HTML mode. - full_path, stat_result = await self.lookup_path("404.html") - if stat_result is not None and stat.S_ISREG(stat_result.st_mode): - return FileResponse( - full_path, - stat_result=stat_result, - method=scope["method"], - status_code=404, - ) - - return PlainTextResponse("Not Found", status_code=404) + raise HTTPException(status_code=404) - async def lookup_path( + def lookup_path( self, path: str ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: for directory in self.all_directories: @@ -153,11 +161,7 @@ async def lookup_path( # Don't allow misbehaving clients to break out of the static files # directory. continue - try: - stat_result = await anyio.to_thread.run_sync(os.stat, full_path) - return full_path, stat_result - except FileNotFoundError: - pass + return full_path, os.stat(full_path) return "", None def file_response( diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index d5ec1afc5..48fdaf1a5 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,11 +1,13 @@ import os import pathlib +import stat import time import anyio import pytest from starlette.applications import Starlette +from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.routing import Mount from starlette.staticfiles import StaticFiles @@ -71,8 +73,10 @@ def test_staticfiles_post(tmpdir, test_client_factory): with open(path, "w") as file: file.write("") - app = StaticFiles(directory=tmpdir) + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) client = test_client_factory(app) + response = client.post("/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" @@ -83,8 +87,10 @@ def test_staticfiles_with_directory_returns_404(tmpdir, test_client_factory): with open(path, "w") as file: file.write("") - app = StaticFiles(directory=tmpdir) + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) client = test_client_factory(app) + response = client.get("/") assert response.status_code == 404 assert response.text == "Not Found" @@ -95,8 +101,10 @@ def test_staticfiles_with_missing_file_returns_404(tmpdir, test_client_factory): with open(path, "w") as file: file.write("") - app = StaticFiles(directory=tmpdir) + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) client = test_client_factory(app) + response = client.get("/404.txt") assert response.status_code == 404 assert response.text == "Not Found" @@ -136,11 +144,15 @@ def test_staticfiles_config_check_occurs_only_once(tmpdir, test_client_factory): app = StaticFiles(directory=tmpdir) client = test_client_factory(app) assert not app.config_checked - client.get("/") - assert app.config_checked - client.get("/") + + with pytest.raises(HTTPException): + client.get("/") + assert app.config_checked + with pytest.raises(HTTPException): + client.get("/") + def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): directory = os.path.join(tmpdir, "foo") @@ -154,9 +166,12 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): # We can't test this with 'requests', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} - response = anyio.run(app.get_response, path, scope) - assert response.status_code == 404 - assert response.body == b"Not Found" + + with pytest.raises(HTTPException) as exc_info: + anyio.run(app.get_response, path, scope) + + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "Not Found" def test_staticfiles_never_read_file_for_head_method(tmpdir, test_client_factory): @@ -284,3 +299,70 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( ) assert resp_deleted.status_code == 404 assert resp_deleted.text == "

404 file

" + + +def test_staticfiles_with_invalid_dir_permissions_returns_401( + tmpdir, test_client_factory +): + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + os.chmod(tmpdir, stat.S_IRWXO) + + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) + client = test_client_factory(app) + + response = client.get("/example.txt") + assert response.status_code == 401 + assert response.text == "Unauthorized" + + +def test_staticfiles_with_missing_dir_returns_404(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) + client = test_client_factory(app) + + response = client.get("/foo/example.txt") + assert response.status_code == 404 + assert response.text == "Not Found" + + +def test_staticfiles_access_file_as_dir_returns_404(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) + client = test_client_factory(app) + + response = client.get("/example.txt/foo") + assert response.status_code == 404 + assert response.text == "Not Found" + + +def test_staticfiles_unhandled_os_error_returns_500( + tmpdir, test_client_factory, monkeypatch +): + def mock_timeout(*args, **kwargs): + raise TimeoutError + + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] + app = Starlette(routes=routes) + client = test_client_factory(app, raise_server_exceptions=False) + + monkeypatch.setattr("starlette.staticfiles.StaticFiles.lookup_path", mock_timeout) + + response = client.get("/example.txt") + assert response.status_code == 500 + assert response.text == "Internal Server Error"