-
Notifications
You must be signed in to change notification settings - Fork 756
/
access.py
100 lines (84 loc) · 3.77 KB
/
access.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from __future__ import annotations
import typing as t
import logging
import functools
from timeit import default_timer
from typing import TYPE_CHECKING
from bentoml.grpc.utils import to_http_status
from bentoml.grpc.utils import wrap_rpc_handler
from bentoml.grpc.utils import GRPC_CONTENT_TYPE
if TYPE_CHECKING:
import grpc
from grpc import aio
from grpc.aio._typing import MetadataType # pylint: disable=unused-import
from bentoml.grpc.types import Request
from bentoml.grpc.types import Response
from bentoml.grpc.types import RpcMethodHandler
from bentoml.grpc.types import AsyncHandlerMethod
from bentoml.grpc.types import HandlerCallDetails
from bentoml.grpc.types import BentoServicerContext
from bentoml.grpc.v1alpha1 import service_pb2 as pb
else:
from bentoml.grpc.utils import import_grpc
from bentoml.grpc.utils import import_generated_stubs
pb, _ = import_generated_stubs()
grpc, aio = import_grpc()
class AccessLogServerInterceptor(aio.ServerInterceptor):
    """
    An asyncio interceptor for access logging.

    Each non-streaming RPC is wrapped so that one record is written to the
    ``bentoml.access`` logger after the handler finishes. The record contains
    the peer, the request and response descriptions (path, content type,
    serialized size, HTTP and gRPC status) and the latency in milliseconds.
    Handlers that stream requests or responses are returned unchanged.
    """

    async def intercept_service(
        self,
        continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]],
        handler_call_details: HandlerCallDetails,
    ) -> RpcMethodHandler:
        """
        Resolve the next handler and wrap it with access logging.

        Args:
            continuation: Coroutine function that yields the next handler in
                the interceptor chain for ``handler_call_details``.
            handler_call_details: Details of the incoming call; its ``method``
                is logged as the request path.

        Returns:
            The handler untouched when it is a streaming handler, otherwise
            whatever ``wrap_rpc_handler`` returns for the logging wrapper.
        """
        logger = logging.getLogger("bentoml.access")
        handler = await continuation(handler_call_details)
        method_name = handler_call_details.method

        # Streaming RPCs are not access-logged; hand them back as-is.
        if handler and (handler.response_streaming or handler.request_streaming):
            return handler

        def wrapper(behaviour: AsyncHandlerMethod[Response]):
            @functools.wraps(behaviour)
            async def new_behaviour(
                request: Request, context: BentoServicerContext
            ) -> Response | t.Awaitable[Response]:
                # Default content type, overridden by a "content-type" entry
                # in the trailing metadata when one is present.
                content_type = GRPC_CONTENT_TYPE
                trailing_metadata: MetadataType | None = context.trailing_metadata()
                if trailing_metadata:
                    trailing = dict(trailing_metadata)
                    content_type = trailing.get("content-type", GRPC_CONTENT_TYPE)

                # Placeholder so that a failed call still has a response whose
                # size can be logged and which can be returned to the caller.
                response = pb.Response()
                start = default_timer()
                try:
                    response = await behaviour(request, context)
                except Exception as e:  # pylint: disable=broad-except
                    # Report the failure through the RPC status rather than
                    # re-raising; the access record below reflects it.
                    context.set_code(grpc.StatusCode.INTERNAL)
                    context.set_details(str(e))
                finally:
                    # default_timer() is in seconds; log milliseconds.
                    latency = max(default_timer() - start, 0) * 1000

                    req = [
                        "scheme=http",  # TODO: support https when ssl is added
                        f"path={method_name}",
                        f"type={content_type}",
                        f"size={request.ByteSize()}",
                    ]

                    # Note that in order for AccessLogServerInterceptor to work, the
                    # interceptor must be added to the server after AsyncOpenTeleServerInterceptor
                    # and PrometheusServerInterceptor.
                    status_code = context.code()
                    if status_code is None:
                        # code() may be None when nothing has called set_code()
                        # on this context. Treat that as success instead of
                        # raising AttributeError from inside ``finally``, which
                        # would turn a successful call into a failure.
                        status_code = grpc.StatusCode.OK
                    typed_context_code = t.cast(grpc.StatusCode, status_code)
                    resp = [
                        f"http_status={to_http_status(typed_context_code)}",
                        f"grpc_status={typed_context_code.value[0]}",
                        f"type={content_type}",
                        f"size={response.ByteSize()}",
                    ]

                    logger.info(
                        "%s (%s) (%s) %.3fms",
                        context.peer(),
                        ",".join(req),
                        ",".join(resp),
                        latency,
                    )
                return response

            return new_behaviour

        return wrap_rpc_handler(wrapper, handler)