diff --git a/uvicorn/config.py b/uvicorn/config.py index ea7888aee..e1c594b9e 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -241,6 +241,7 @@ def __init__( limit_concurrency: Optional[int] = None, limit_max_requests: Optional[int] = None, backlog: int = 2048, + timeout_request_start: int = 10, timeout_keep_alive: int = 5, timeout_notify: int = 30, callback_notify: Optional[Callable[..., Awaitable[None]]] = None, @@ -283,6 +284,7 @@ def __init__( self.limit_concurrency = limit_concurrency self.limit_max_requests = limit_max_requests self.backlog = backlog + self.timeout_request_start = timeout_request_start self.timeout_keep_alive = timeout_keep_alive self.timeout_notify = timeout_notify self.callback_notify = callback_notify diff --git a/uvicorn/main.py b/uvicorn/main.py index 16f5aa117..3f44dc6da 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -267,6 +267,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No default=None, help="Maximum number of requests to service before terminating the process.", ) +@click.option( + "--timeout-request-start", + type=int, + default=10, + help="Timeout unless request headers complete within this time.", + show_default=True, +) @click.option( "--timeout-keep-alive", type=int, @@ -387,6 +394,7 @@ def main( limit_concurrency: int, backlog: int, limit_max_requests: int, + timeout_request_start: int, timeout_keep_alive: int, ssl_keyfile: str, ssl_certfile: str, @@ -434,6 +442,7 @@ def main( limit_concurrency=limit_concurrency, backlog=backlog, limit_max_requests=limit_max_requests, + timeout_request_start=timeout_request_start, timeout_keep_alive=timeout_keep_alive, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, @@ -486,6 +495,7 @@ def run( limit_concurrency: typing.Optional[int] = None, backlog: int = 2048, limit_max_requests: typing.Optional[int] = None, + timeout_request_start: int = 10, timeout_keep_alive: int = 5, ssl_keyfile: typing.Optional[str] = None, ssl_certfile: typing.Optional[typing.Union[str, os.PathLike]] = None, @@ -536,6 +546,7 @@ def run( limit_concurrency=limit_concurrency, backlog=backlog, limit_max_requests=limit_max_requests, + timeout_request_start=timeout_request_start, timeout_keep_alive=timeout_keep_alive, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index f9974c182..eeb4417b8 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -85,6 +85,8 @@ def __init__( self.limit_concurrency = config.limit_concurrency # Timeouts + self.timeout_request_start_task: Optional[asyncio.TimerHandle] = None + self.timeout_request_start = config.timeout_request_start self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None self.timeout_keep_alive = config.timeout_keep_alive @@ -104,6 +106,7 @@ def __init__( self.scope: HTTPScope = None # type: ignore[assignment] self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment] self.cycle: RequestResponseCycle = None # type: ignore[assignment] + self.timeout_request_start_task_called = False # Protocol interface def connection_made( # type: ignore[override] @@ -121,6 +124,11 @@ def connection_made( # type: ignore[override] prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) + print(self.loop.time()) + self.timeout_request_start_task = self.loop.call_later( + self.config.timeout_request_start, self.timeout_request_start_handler + ) + def connection_lost(self, exc: Optional[Exception]) -> None: self.connections.discard(self) @@ -154,6 +162,11 @@ def _unset_keepalive_if_required(self) -> None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None + def _unset_request_start_if_required(self) -> None: + if self.timeout_request_start_task is not None: + self.timeout_request_start_task.cancel() + self.timeout_request_start_task = None + def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() @@ -358,6 +371,9 @@ def timeout_keep_alive_handler(self) -> None: self.conn.send(event) self.transport.close() + def timeout_request_start_handler(self) -> None: + self.timeout_request_start_task_called = True + class RequestResponseCycle: def __init__( diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index fe3f826dd..d79031b29 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -83,6 +83,8 @@ def __init__( self.limit_concurrency = config.limit_concurrency # Timeouts + self.timeout_request_start_task: Optional[TimerHandle] = None + self.timeout_request_start = config.timeout_request_start self.timeout_keep_alive_task: Optional[TimerHandle] = None self.timeout_keep_alive = config.timeout_keep_alive @@ -103,6 +105,7 @@ def __init__( self.scope: HTTPScope = None # type: ignore[assignment] self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment] self.expect_100_continue = False + self.timeout_request_start_task_called = False self.cycle: RequestResponseCycle = None # type: ignore[assignment] # Protocol interface @@ -121,6 +124,11 @@ def connection_made( # type: ignore[override] prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) + print(self.loop.time()) + self.timeout_request_start_task = self.loop.call_later( + self.config.timeout_request_start, self.timeout_request_start_handler + ) + def connection_lost(self, exc: Optional[Exception]) -> None: self.connections.discard(self) @@ -148,9 +156,20 @@ def _unset_keepalive_if_required(self) -> None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None + def _unset_request_start_if_required(self) -> None: + if self.timeout_request_start_task is not None: + self.timeout_request_start_task.cancel() + self.timeout_request_start_task = None + def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() + if self.timeout_request_start_task_called: + msg = "Timeout on request headers." + self.logger.warning(msg) + self.send_400_response(msg) + return + try: self.parser.feed_data(data) except httptools.HttpParserError: @@ -196,7 +215,6 @@ def handle_upgrade(self) -> None: self.transport.set_protocol(protocol) def send_400_response(self, msg: str) -> None: - content = [STATUS_LINE[400]] for name, value in self.server_state.default_headers: content.extend([name, b": ", value, b"\r\n"]) @@ -238,6 +256,10 @@ def on_header(self, name: bytes, value: bytes) -> None: self.headers.append((name, value)) def on_headers_complete(self) -> None: + if self.timeout_request_start_task_called: + return + self._unset_request_start_if_required() + http_version = self.parser.get_http_version() method = self.parser.get_method() self.scope["method"] = method.decode("ascii") @@ -356,6 +378,9 @@ def timeout_keep_alive_handler(self) -> None: if not self.transport.is_closing(): self.transport.close() + def timeout_request_start_handler(self) -> None: + self.timeout_request_start_task_called = True + class RequestResponseCycle: def __init__(