diff --git a/sanic/app.py b/sanic/app.py index d78c67ded7..566266e06d 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1631,7 +1631,9 @@ async def _startup(self): self._future_registry.clear() self.signalize() self.finalize() - ErrorHandler.finalize(self.error_handler) + ErrorHandler.finalize( + self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT + ) TouchUp.run(self) async def _server_event( diff --git a/sanic/handlers.py b/sanic/handlers.py index af667c9a8e..046e56e18c 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -38,7 +38,14 @@ def __init__( self.base = base @classmethod - def finalize(cls, error_handler): + def finalize(cls, error_handler, fallback: Optional[str] = None): + if ( + fallback + and fallback != "auto" + and error_handler.fallback == "auto" + ): + error_handler.fallback = fallback + if not isinstance(error_handler, cls): error_logger.warning( f"Error handler is non-conforming: {type(error_handler)}" diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 84949fde5c..1843f6a707 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,6 +1,7 @@ import pytest from sanic import Sanic +from sanic.config import Config from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException from sanic.handlers import ErrorHandler @@ -313,3 +314,31 @@ def test_setting_fallback_to_non_default_raise_warning(app): app.config.FALLBACK_ERROR_FORMAT = "json" assert app.error_handler.fallback == "json" + + +def test_allow_fallback_error_format_in_config_injection(): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app = Sanic("test", config=MyConfig()) + + @app.route("/error", methods=["GET", "POST"]) + def err(request): + raise Exception("something went wrong") + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_allow_fallback_error_format_in_config_replacement(app): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app.config = MyConfig() + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8"