Skip to content

Commit

Permalink
Emit a signal about sending headers in client tracing API (#5122)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
  • Loading branch information
derlih and asvetlov committed Oct 25, 2020
1 parent 4c64ddd commit f244621
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGES/5105.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Tracing for client sent headers
9 changes: 9 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ async def send(self, conn: "Connection") -> "ClientResponse":
on_chunk_sent=functools.partial(
self._on_chunk_request_sent, self.method, self.url
),
on_headers_sent=functools.partial(
self._on_headers_request_sent, self.method, self.url
),
)

if self.compress:
Expand Down Expand Up @@ -634,6 +637,12 @@ async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> N
for trace in self._traces:
await trace.send_request_chunk_sent(method, url, chunk)

async def _on_headers_request_sent(
self, method: str, url: URL, headers: "CIMultiDict[str]"
) -> None:
for trace in self._traces:
await trace.send_request_headers(method, url, headers)


class ClientResponse(HeadersMixin):

Expand Down
6 changes: 6 additions & 0 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]


class StreamWriter(AbstractStreamWriter):
Expand All @@ -27,6 +28,7 @@ def __init__(
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
on_chunk_sent: _T_OnChunkSent = None,
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self._transport = protocol.transport
Expand All @@ -42,6 +44,7 @@ def __init__(
self._drain_waiter = None

self._on_chunk_sent = on_chunk_sent # type: _T_OnChunkSent
self._on_headers_sent = on_headers_sent # type: _T_OnHeadersSent

@property
def transport(self) -> Optional[asyncio.Transport]:
Expand Down Expand Up @@ -114,6 +117,9 @@ async def write_headers(
self, status_line: str, headers: "CIMultiDict[str]"
) -> None:
"""Write request/response status and headers."""
if self._on_headers_sent is not None:
await self._on_headers_sent(headers)

# status + headers
buf = _serialize_headers(status_line, headers)
self._write(buf)
Expand Down
29 changes: 29 additions & 0 deletions aiohttp/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __call__(
"TraceRequestRedirectParams",
"TraceRequestChunkSentParams",
"TraceResponseChunkReceivedParams",
"TraceRequestHeadersSentParams",
)


Expand Down Expand Up @@ -97,6 +98,9 @@ def __init__(
self._on_dns_cache_miss = Signal(
self
) # type: Signal[_SignalCallback[TraceDnsCacheMissParams]]
self._on_request_headers_sent = Signal(
self
) # type: Signal[_SignalCallback[TraceRequestHeadersSentParams]]

self._trace_config_ctx_factory = trace_config_ctx_factory

Expand All @@ -122,6 +126,7 @@ def freeze(self) -> None:
self._on_dns_resolvehost_end.freeze()
self._on_dns_cache_hit.freeze()
self._on_dns_cache_miss.freeze()
self._on_request_headers_sent.freeze()

@property
def on_request_start(self) -> "Signal[_SignalCallback[TraceRequestStartParams]]":
Expand Down Expand Up @@ -205,6 +210,12 @@ def on_dns_cache_hit(self) -> "Signal[_SignalCallback[TraceDnsCacheHitParams]]":
def on_dns_cache_miss(self) -> "Signal[_SignalCallback[TraceDnsCacheMissParams]]":
return self._on_dns_cache_miss

@property
def on_request_headers_sent(
self,
) -> "Signal[_SignalCallback[TraceRequestHeadersSentParams]]":
return self._on_request_headers_sent


@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestStartParams:
Expand Down Expand Up @@ -316,6 +327,15 @@ class TraceDnsCacheMissParams:
host: str


@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestHeadersSentParams:
""" Parameters sent by the `on_request_headers_sent` signal"""

method: str
url: URL
headers: "CIMultiDict[str]"


class Trace:
"""Internal class used to keep together the main dependencies used
at the moment of send a signal."""
Expand Down Expand Up @@ -440,3 +460,12 @@ async def send_dns_cache_miss(self, host: str) -> None:
return await self._trace_config.on_dns_cache_miss.send(
self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host)
)

async def send_request_headers(
self, method: str, url: URL, headers: "CIMultiDict[str]"
) -> None:
return await self._trace_config._on_request_headers_sent.send(
self._session,
self._trace_config_ctx,
TraceRequestHeadersSentParams(method, url, headers),
)
31 changes: 30 additions & 1 deletion docs/tracing_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Overview

acquire_connection[description="Connection acquiring"];
headers_received;
headers_sent;
headers_sent[description="on_request_headers_sent"];
chunk_sent[description="on_request_chunk_sent"];
chunk_received[description="on_response_chunk_received"];

Expand Down Expand Up @@ -269,6 +269,14 @@ TraceConfig

``params`` is :class:`aiohttp.TraceDnsCacheMissParams` instance.

.. attribute:: on_request_headers_sent

Property that gives access to the signals that will be executed
when request headers are sent.

``params`` is :class:`aiohttp.TraceRequestHeadersSentParams` instance.

.. versionadded:: 3.8

TraceRequestStartParams
-----------------------
Expand Down Expand Up @@ -492,3 +500,24 @@ TraceDnsCacheMissParams
.. attribute:: host

Host didn't find the cache.

TraceRequestHeadersSentParams
-----------------------------

.. class:: TraceRequestHeadersSentParams

See :attr:`TraceConfig.on_request_headers_sent` for details.

.. versionadded:: 3.8

.. attribute:: method

Method that will be used to make the request.

.. attribute:: url

URL that will be used for the request.

.. attribute:: headers

Headers that will be used for the request.
15 changes: 12 additions & 3 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ async def handler(request):
body = "This is request body"
gathered_req_body = BytesIO()
gathered_res_body = BytesIO()
gathered_req_headers = CIMultiDict()
on_request_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock()))
on_request_redirect = mock.Mock(side_effect=make_mocked_coro(mock.Mock()))
on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock()))
Expand All @@ -543,6 +544,9 @@ async def on_request_chunk_sent(session, context, params):
async def on_response_chunk_received(session, context, params):
gathered_res_body.write(params.chunk)

async def on_request_headers_sent(session, context, params):
gathered_req_headers.extend(**params.headers)

trace_config = aiohttp.TraceConfig(
trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx)
)
Expand All @@ -551,8 +555,12 @@ async def on_response_chunk_received(session, context, params):
trace_config.on_request_chunk_sent.append(on_request_chunk_sent)
trace_config.on_response_chunk_received.append(on_response_chunk_received)
trace_config.on_request_redirect.append(on_request_redirect)
trace_config.on_request_headers_sent.append(on_request_headers_sent)

session = await aiohttp_client(app, trace_configs=[trace_config])
headers = CIMultiDict({"Custom-Header": "Custom value"})
session = await aiohttp_client(
app, trace_configs=[trace_config], headers=headers
)

async with session.post(
"/", data=body, trace_request_ctx=trace_request_ctx
Expand All @@ -564,20 +572,21 @@ async def on_response_chunk_received(session, context, params):
session.session,
trace_config_ctx,
aiohttp.TraceRequestStartParams(
hdrs.METH_POST, session.make_url("/"), CIMultiDict()
hdrs.METH_POST, session.make_url("/"), headers
),
)

on_request_end.assert_called_once_with(
session.session,
trace_config_ctx,
aiohttp.TraceRequestEndParams(
hdrs.METH_POST, session.make_url("/"), CIMultiDict(), resp
hdrs.METH_POST, session.make_url("/"), headers, resp
),
)
assert not on_request_redirect.called
assert gathered_req_body.getvalue() == body.encode("utf8")
assert gathered_res_body.getvalue() == json.dumps({"ok": True}).encode("utf8")
assert gathered_req_headers["Custom-Header"] == "Custom value"


async def test_request_tracing_exception() -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_freeze(self) -> None:
assert trace_config.on_dns_resolvehost_end.frozen
assert trace_config.on_dns_cache_hit.frozen
assert trace_config.on_dns_cache_miss.frozen
assert trace_config.on_request_headers_sent.frozen


class TestTrace:
Expand Down

0 comments on commit f244621

Please sign in to comment.