Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement --timeout-request-start flag #1685

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions uvicorn/config.py
Expand Up @@ -241,6 +241,7 @@ def __init__(
limit_concurrency: Optional[int] = None,
limit_max_requests: Optional[int] = None,
backlog: int = 2048,
timeout_request_start: int = 10,
timeout_keep_alive: int = 5,
timeout_notify: int = 30,
callback_notify: Optional[Callable[..., Awaitable[None]]] = None,
Expand Down Expand Up @@ -283,6 +284,7 @@ def __init__(
self.limit_concurrency = limit_concurrency
self.limit_max_requests = limit_max_requests
self.backlog = backlog
self.timeout_request_start = timeout_request_start
self.timeout_keep_alive = timeout_keep_alive
self.timeout_notify = timeout_notify
self.callback_notify = callback_notify
Expand Down
11 changes: 11 additions & 0 deletions uvicorn/main.py
Expand Up @@ -267,6 +267,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
default=None,
help="Maximum number of requests to service before terminating the process.",
)
@click.option(
"--timeout-request-start",
type=int,
default=10,
help="Timeout unless request headers complete within this time.",
show_default=True,
)
@click.option(
"--timeout-keep-alive",
type=int,
Expand Down Expand Up @@ -387,6 +394,7 @@ def main(
limit_concurrency: int,
backlog: int,
limit_max_requests: int,
timeout_request_start: int,
timeout_keep_alive: int,
ssl_keyfile: str,
ssl_certfile: str,
Expand Down Expand Up @@ -434,6 +442,7 @@ def main(
limit_concurrency=limit_concurrency,
backlog=backlog,
limit_max_requests=limit_max_requests,
timeout_request_start=timeout_request_start,
timeout_keep_alive=timeout_keep_alive,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
Expand Down Expand Up @@ -486,6 +495,7 @@ def run(
limit_concurrency: typing.Optional[int] = None,
backlog: int = 2048,
limit_max_requests: typing.Optional[int] = None,
timeout_request_start: int = 10,
timeout_keep_alive: int = 5,
ssl_keyfile: typing.Optional[str] = None,
ssl_certfile: typing.Optional[typing.Union[str, os.PathLike]] = None,
Expand Down Expand Up @@ -536,6 +546,7 @@ def run(
limit_concurrency=limit_concurrency,
backlog=backlog,
limit_max_requests=limit_max_requests,
timeout_request_start=timeout_request_start,
timeout_keep_alive=timeout_keep_alive,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
Expand Down
16 changes: 16 additions & 0 deletions uvicorn/protocols/http/h11_impl.py
Expand Up @@ -85,6 +85,8 @@ def __init__(
self.limit_concurrency = config.limit_concurrency

# Timeouts
self.timeout_request_start_task: Optional[asyncio.TimerHandle] = None
self.timeout_request_start = config.timeout_request_start
self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None
self.timeout_keep_alive = config.timeout_keep_alive

Expand All @@ -104,6 +106,7 @@ def __init__(
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
self.timeout_request_start_task_called = False

# Protocol interface
def connection_made( # type: ignore[override]
Expand All @@ -121,6 +124,11 @@ def connection_made( # type: ignore[override]
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

print(self.loop.time())
self.timeout_request_start_task = self.loop.call_later(
self.config.timeout_request_start, self.timeout_request_start_handler
)

def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.discard(self)

Expand Down Expand Up @@ -154,6 +162,11 @@ def _unset_keepalive_if_required(self) -> None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def _unset_request_start_if_required(self) -> None:
if self.timeout_request_start_task is not None:
self.timeout_request_start_task.cancel()
self.timeout_request_start_task = None

def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

Expand Down Expand Up @@ -358,6 +371,9 @@ def timeout_keep_alive_handler(self) -> None:
self.conn.send(event)
self.transport.close()

def timeout_request_start_handler(self) -> None:
self.timeout_request_start_task_called = True


class RequestResponseCycle:
def __init__(
Expand Down
27 changes: 26 additions & 1 deletion uvicorn/protocols/http/httptools_impl.py
Expand Up @@ -83,6 +83,8 @@ def __init__(
self.limit_concurrency = config.limit_concurrency

# Timeouts
self.timeout_request_start_task: Optional[TimerHandle] = None
self.timeout_request_start = config.timeout_request_start
self.timeout_keep_alive_task: Optional[TimerHandle] = None
self.timeout_keep_alive = config.timeout_keep_alive

Expand All @@ -103,6 +105,7 @@ def __init__(
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.expect_100_continue = False
self.timeout_request_start_task_called = False
self.cycle: RequestResponseCycle = None # type: ignore[assignment]

# Protocol interface
Expand All @@ -121,6 +124,11 @@ def connection_made( # type: ignore[override]
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

print(self.loop.time())
self.timeout_request_start_task = self.loop.call_later(
self.config.timeout_request_start, self.timeout_request_start_handler
)

def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.discard(self)

Expand Down Expand Up @@ -148,9 +156,20 @@ def _unset_keepalive_if_required(self) -> None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def _unset_request_start_if_required(self) -> None:
if self.timeout_request_start_task is not None:
self.timeout_request_start_task.cancel()
self.timeout_request_start_task = None

def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

if self.timeout_request_start_task_called:
msg = "Timeout on request headers."
self.logger.warning(msg)
self.send_400_response(msg)
return

try:
self.parser.feed_data(data)
except httptools.HttpParserError:
Expand Down Expand Up @@ -196,7 +215,6 @@ def handle_upgrade(self) -> None:
self.transport.set_protocol(protocol)

def send_400_response(self, msg: str) -> None:

content = [STATUS_LINE[400]]
for name, value in self.server_state.default_headers:
content.extend([name, b": ", value, b"\r\n"])
Expand Down Expand Up @@ -238,6 +256,10 @@ def on_header(self, name: bytes, value: bytes) -> None:
self.headers.append((name, value))

def on_headers_complete(self) -> None:
if self.timeout_request_start_task_called:
return
self._unset_request_start_if_required()

http_version = self.parser.get_http_version()
method = self.parser.get_method()
self.scope["method"] = method.decode("ascii")
Expand Down Expand Up @@ -356,6 +378,9 @@ def timeout_keep_alive_handler(self) -> None:
if not self.transport.is_closing():
self.transport.close()

def timeout_request_start_handler(self) -> None:
self.timeout_request_start_task_called = True


class RequestResponseCycle:
def __init__(
Expand Down