Skip to content

Commit

Permalink
Add missing trace log for websocket protocols (#1083)
Browse files Browse the repository at this point in the history
* Add missing trace log for websocket protocols

Formerly we only have trace log for http protocol. It seems we missed
to add trace log for websocket protocols.

* Add trace logging tests for protocols

Protocols emit trace logs on `connection_made()` and `connection_lost`.
Add missing tests for these logs.

* Specify connection type in protocol trace log

The default transport protocol is http protocol. A ws request
is firstly handled using http protocol, then handled using ws protocol
by switching the protocol. To discriminate the trace log emitter,
we need to specify the protocol type in the log message.
  • Loading branch information
laggardkernel authored and Kludex committed Nov 17, 2021
1 parent 593897a commit 6370d45
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 4 deletions.
51 changes: 51 additions & 0 deletions tests/middleware/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import httpx
import pytest
import websockets

from tests.utils import run_server
from uvicorn import Config
Expand Down Expand Up @@ -45,6 +46,56 @@ async def test_trace_logging(caplog):
assert "ASGI [2] Completed" in messages.pop(0)


@pytest.mark.asyncio
@pytest.mark.parametrize("http_protocol", [("h11"), ("httptools")])
async def test_trace_logging_on_http_protocol(http_protocol, caplog):
config = Config(app=app, log_level="trace", http=http_protocol)
with caplog_for_logger(caplog, "uvicorn.error"):
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get("http://127.0.0.1:8000")
assert response.status_code == 204
messages = [
record.message
for record in caplog.records
if record.name == "uvicorn.error"
]
assert any(" - HTTP connection made" in message for message in messages)
assert any(" - HTTP connection lost" in message for message in messages)


@pytest.mark.asyncio
@pytest.mark.parametrize("ws_protocol", [("websockets"), ("wsproto")])
async def test_trace_logging_on_ws_protocol(ws_protocol, caplog):
async def websocket_app(scope, receive, send):
assert scope["type"] == "websocket"
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.open

config = Config(app=websocket_app, log_level="trace", ws=ws_protocol)
with caplog_for_logger(caplog, "uvicorn.error"):
async with run_server(config):
is_open = await open_connection("ws://127.0.0.1:8000")
assert is_open
messages = [
record.message
for record in caplog.records
if record.name == "uvicorn.error"
]
print(messages)
assert any(" - Upgrading to WebSocket" in message for message in messages)
assert any(" - WebSocket connection made" in message for message in messages)
assert any(" - WebSocket connection lost" in message for message in messages)


@pytest.mark.asyncio
@pytest.mark.parametrize("use_colors", [(True), (False), (None)])
async def test_access_logging(use_colors, caplog):
Expand Down
8 changes: 6 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def connection_made(self, transport):

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection made", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection lost", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
self.cycle.disconnected = True
Expand Down Expand Up @@ -249,6 +249,10 @@ def handle_upgrade(self, event):
self.transport.close()
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
for name, value in self.headers:
Expand Down
8 changes: 6 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ def connection_made(self, transport):

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection made", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection lost", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
self.cycle.disconnected = True
Expand Down Expand Up @@ -161,6 +161,10 @@ def handle_upgrade(self):
self.transport.close()
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
method = self.scope["method"].encode()
output = [method, b" ", self.url, b" HTTP/1.1\r\n"]
Expand Down
11 changes: 11 additions & 0 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import websockets
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory

from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl


Expand Down Expand Up @@ -70,10 +71,20 @@ def connection_made(self, transport):
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

super().connection_made(transport)

def connection_lost(self, exc):
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.handshake_completed_event.set()
super().connection_lost(exc)

Expand Down
9 changes: 9 additions & 0 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import RemoteProtocolError

from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl

# Check wsproto version. We've build against 0.13. We don't know about 0.14 yet.
Expand Down Expand Up @@ -61,11 +62,19 @@ def connection_made(self, transport):
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
if exc is not None:
self.queue.put_nowait({"type": "websocket.disconnect"})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

def eof_received(self):
pass

Expand Down

0 comments on commit 6370d45

Please sign in to comment.