diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 39a697260..f7057539f 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -111,20 +111,6 @@ async def get_response(self, path: str, scope: Scope) -> Response: full_path, stat_result = await anyio.to_thread.run_sync( self.lookup_path, path ) - except (FileNotFoundError, NotADirectoryError): - if self.html: - # Check for '404.html' if we're in HTML mode. - full_path, stat_result = await anyio.to_thread.run_sync( - self.lookup_path, "404.html" - ) - if stat_result and stat.S_ISREG(stat_result.st_mode): - return FileResponse( - full_path, - stat_result=stat_result, - method=scope["method"], - status_code=404, - ) - raise HTTPException(status_code=404) except PermissionError: raise HTTPException(status_code=401) except OSError: @@ -149,6 +135,18 @@ async def get_response(self, path: str, scope: Scope) -> Response: return RedirectResponse(url=url) return self.file_response(full_path, stat_result, scope) + if self.html: + # Check for '404.html' if we're in HTML mode. + full_path, stat_result = await anyio.to_thread.run_sync( + self.lookup_path, "404.html" + ) + if stat_result and stat.S_ISREG(stat_result.st_mode): + return FileResponse( + full_path, + stat_result=stat_result, + method=scope["method"], + status_code=404, + ) raise HTTPException(status_code=404) def lookup_path( @@ -161,7 +159,10 @@ def lookup_path( # Don't allow misbehaving clients to break out of the static files # directory. continue - return full_path, os.stat(full_path) + try: + return full_path, os.stat(full_path) + except (FileNotFoundError, NotADirectoryError): + continue return "", None def file_response( diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 48fdaf1a5..8057af689 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -229,7 +229,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req( assert response.content == b"" -def test_staticfiles_html(tmpdir, test_client_factory): +def test_staticfiles_html_normal(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -262,6 +262,73 @@ def test_staticfiles_html(tmpdir, test_client_factory): assert response.text == "

Custom not found page

" +def test_staticfiles_html_without_index(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "404.html") + with open(path, "w") as file: + file.write("

Custom not found page

") + path = os.path.join(tmpdir, "dir") + os.mkdir(path) + + app = StaticFiles(directory=tmpdir, html=True) + client = test_client_factory(app) + + response = client.get("/dir/") + assert response.url == "http://testserver/dir/" + assert response.status_code == 404 + assert response.text == "

Custom not found page

" + + response = client.get("/dir") + assert response.url == "http://testserver/dir" + assert response.status_code == 404 + assert response.text == "

Custom not found page

" + + response = client.get("/missing") + assert response.status_code == 404 + assert response.text == "

Custom not found page

" + + +def test_staticfiles_html_without_404(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "dir") + os.mkdir(path) + path = os.path.join(path, "index.html") + with open(path, "w") as file: + file.write("

Hello

") + + app = StaticFiles(directory=tmpdir, html=True) + client = test_client_factory(app) + + response = client.get("/dir/") + assert response.url == "http://testserver/dir/" + assert response.status_code == 200 + assert response.text == "

Hello

" + + response = client.get("/dir") + assert response.url == "http://testserver/dir/" + assert response.status_code == 200 + assert response.text == "

Hello

" + + with pytest.raises(HTTPException) as exc_info: + response = client.get("/missing") + assert exc_info.value.status_code == 404 + + +def test_staticfiles_html_only_files(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "hello.html") + with open(path, "w") as file: + file.write("

Hello

") + + app = StaticFiles(directory=tmpdir, html=True) + client = test_client_factory(app) + + with pytest.raises(HTTPException) as exc_info: + response = client.get("/") + assert exc_info.value.status_code == 404 + + response = client.get("/hello.html") + assert response.status_code == 200 + assert response.text == "

Hello

" + + def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( tmpdir, test_client_factory ):