fix(runner): receive requests before cork
bojiang committed Sep 16, 2022
1 parent 532f5e1 commit 7d8dbd4
Showing 2 changed files with 46 additions and 44 deletions.
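The gist of the change, as visible in the diff below: batchable runner endpoints previously handed whole HTTP requests to the corking dispatcher, and request bodies were received and unpickled only inside the batched callback, so the corked batch had to wait for every body to arrive. Now the endpoint awaits and decodes each body first, and only the decoded Params enter the dispatcher (hence "receive requests before cork"); the batch callback works purely on Params and returns Payloads, from which per-request responses are built. The dispatcher callback's type is tightened from t.Collection to t.Sequence to match.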
bentoml/_internal/marshal/dispatcher.py (1 addition, 1 deletion)
@@ -149,7 +149,7 @@ def _wake_event(self):
     def __call__(
         self,
         callback: t.Callable[
-            [t.Collection[T_IN]], t.Coroutine[None, None, t.Collection[T_OUT]]
+            [t.Sequence[T_IN]], t.Coroutine[None, None, t.Sequence[T_OUT]]
         ],
     ) -> t.Callable[[T_IN], t.Coroutine[None, None, T_OUT]]:
         self.callback = callback
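The Collection-to-Sequence change is load-bearing: the dispatcher corks many concurrent calls into one batched callback invocation, then fans the results back out to the individual waiters by index, so inputs and outputs must be ordered and aligned one-to-one. A minimal sketch of that contract, assuming a toy fixed-window dispatcher (the names cork and square_batch are illustrative, not BentoML APIs):

import asyncio
import typing as t

T_IN = t.TypeVar("T_IN")
T_OUT = t.TypeVar("T_OUT")


def cork(
    batch_fn: t.Callable[[t.Sequence[T_IN]], t.Awaitable[t.Sequence[T_OUT]]],
    window: float = 0.01,
) -> t.Callable[[T_IN], t.Awaitable[T_OUT]]:
    # Collect calls that arrive within `window` seconds into one batch call,
    # then resolve each waiter with the output at its own index. The i-th
    # input must correspond to the i-th output, which is what an ordered
    # Sequence promises and an unordered Collection does not.
    queue: list[tuple[T_IN, asyncio.Future[T_OUT]]] = []

    async def _flush() -> None:
        await asyncio.sleep(window)
        batch, queue[:] = queue[:], []  # snapshot and reset the queue
        outputs = await batch_fn([inp for inp, _ in batch])
        for (_, fut), out in zip(batch, outputs):
            fut.set_result(out)

    async def _call(inp: T_IN) -> T_OUT:
        fut: asyncio.Future[T_OUT] = asyncio.get_running_loop().create_future()
        if not queue:  # first waiter of the window schedules the flush
            asyncio.get_running_loop().create_task(_flush())
        queue.append((inp, fut))
        return await fut

    return _call


async def square_batch(xs: t.Sequence[int]) -> list[int]:
    return [x * x for x in xs]


async def main() -> None:
    infer = cork(square_batch)
    print(await asyncio.gather(*(infer(i) for i in range(4))))  # [0, 1, 4, 9]


asyncio.run(main())

BentoML's real CorkDispatcher layers adaptive batch sizing and latency control on top, but the positional in/out contract is the one the new annotation states.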
bentoml/_internal/server/runner_app.py (45 additions, 43 deletions)
@@ -11,6 +11,8 @@
 from simple_di import inject
 from simple_di import Provide
 
+from bentoml._internal.types import LazyType
+
 from ..context import trace_context
 from ..context import component_context
 from ..runner.utils import Params
@@ -35,6 +37,7 @@
 
     from ..runner.runner import Runner
     from ..runner.runner import RunnerMethod
+    from ..runner.container import Payload
 
 
 class RunnerAppFactory(BaseAppFactory):
@@ -138,13 +141,10 @@ def routes(self) -> t.List[BaseRoute]:
         for method in self.runner.runner_methods:
             path = "/" if method.name == "__call__" else "/" + method.name
             if method.config.batchable:
-                _func = self.dispatchers[method.name](
-                    self._async_cork_run(runner_method=method)
-                )
                 routes.append(
                     Route(
                         path=path,
-                        endpoint=_func,
+                        endpoint=self._mk_request_handler(runner_method=method),
                         methods=["POST"],
                     )
                 )
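After this change the Route receives an ordinary one-request, one-response endpoint; corking happens behind it rather than around it. A minimal analogue of the new endpoint shape (a hypothetical stand-in, not BentoML code):

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route


async def endpoint(request: Request) -> PlainTextResponse:
    body = await request.body()  # receive the request body up front...
    return PlainTextResponse(body[::-1])  # ...then hand off (stub for dispatch)


app = Starlette(routes=[Route("/", endpoint, methods=["POST"])])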
@@ -205,32 +205,30 @@ def client_request_hook(span: Span, _scope: t.Dict[str, t.Any]) -> None:
 
         return middlewares
 
-    def _async_cork_run(
+    def _mk_request_handler(
         self,
         runner_method: RunnerMethod[t.Any, t.Any, t.Any],
-    ) -> t.Callable[[t.Collection[Request]], t.Coroutine[None, None, list[Response]]]:
+    ) -> t.Callable[[Request], t.Coroutine[None, None, Response]]:
         from starlette.responses import Response
 
-        async def _run(requests: t.Collection[Request]) -> list[Response]:
-            assert self._is_ready
+        server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"
 
-            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(
-                len(requests)
+        async def infer_batch(
+            params_list: t.Sequence[Params[t.Any]],
+        ) -> list[Payload] | list[tuple[Payload, ...]]:
+            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(  # type: ignore
+                len(params_list)
             )
-            self.adaptive_batch_size_hist.labels(
+            self.adaptive_batch_size_hist.labels(  # type: ignore
                 runner_name=self.runner.name,
                 worker_index=self.worker_index,
                 method_name=runner_method.name,
                 service_version=component_context.bento_version,
                 service_name=component_context.bento_name,
-            ).observe(len(requests))
+            ).observe(len(params_list))
 
-            if not requests:
+            if not params_list:
                 return []
-            params_list: list[Params[t.Any]] = []
-            for r in requests:
-                r_ = await r.body()
-                params_list.append(pickle.loads(r_))
 
             input_batch_dim, output_batch_dim = runner_method.config.batch_dim
 
@@ -242,50 +240,54 @@ async def _run(requests: t.Collection[Request]) -> list[Response]:
                 *batched_params.args, **batched_params.kwargs
             )
 
-            server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"
-
             # multiple output branch
-            if isinstance(batch_ret, tuple):
+            if LazyType[tuple[t.Any, ...]](tuple).isinstance(batch_ret):
                 output_num = len(batch_ret)
-                payloadss = [
+                payloadss = tuple(
                     AutoContainer.batch_to_payloads(
                         batch_ret[idx], indices, batch_dim=output_batch_dim
                     )
                     for idx in range(output_num)
-                ]
-
-                return [
-                    Response(
-                        pickle.dumps(payloads),
-                        headers={
-                            PAYLOAD_META_HEADER: json.dumps({}),
-                            "Content-Type": "application/vnd.bentoml.multiple_outputs",
-                            "Server": server_str,
-                        },
-                    )
-                    for payloads in zip(*payloadss)
-                ]
+                )
+                return list(zip(*payloadss))
 
             # single output branch
             payloads = AutoContainer.batch_to_payloads(
                 batch_ret,
                 indices,
                 batch_dim=output_batch_dim,
             )
+            return payloads
 
-            return [
-                Response(
-                    payload.data,
+        infer = self.dispatchers[runner_method.name](infer_batch)
+
+        async def _request_handler(request: Request) -> Response:
+            assert self._is_ready
+
+            r_: bytes = await request.body()
+            params: Params[t.Any] = pickle.loads(r_)
+
+            payload = await infer(params)
+
+            if isinstance(payload, tuple):
+                return Response(
+                    pickle.dumps(payload),
                     headers={
-                        PAYLOAD_META_HEADER: json.dumps(payload.meta),
-                        "Content-Type": f"application/vnd.bentoml.{payload.container}",
+                        PAYLOAD_META_HEADER: json.dumps({}),
+                        "Content-Type": "application/vnd.bentoml.multiple_outputs",
                         "Server": server_str,
                     },
                 )
-                for payload in payloads
-            ]
+            return Response(
+                payload.data,
+                headers={
+                    PAYLOAD_META_HEADER: json.dumps(payload.meta),
+                    "Content-Type": f"application/vnd.bentoml.{payload.container}",
+                    "Server": server_str,
+                },
+            )
 
-        return _run
+        return _request_handler
 
     def async_run(
         self,
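In the multiple-output branch, infer_batch now returns plain per-request tuples instead of prebuilt responses. AutoContainer.batch_to_payloads yields one list per output head, each of batch length, and list(zip(*payloadss)) transposes them so each corked request gets a tuple with one payload per output. The transpose in isolation, with illustrative placeholder values:

out0 = ["p0_r0", "p0_r1", "p0_r2"]  # first output head, batch of 3 requests
out1 = ["p1_r0", "p1_r1", "p1_r2"]  # second output head, same 3 requests
payloadss = (out0, out1)
per_request = list(zip(*payloadss))  # one tuple per corked request
assert per_request == [("p0_r0", "p1_r0"), ("p0_r1", "p1_r1"), ("p0_r2", "p1_r2")]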
@@ -302,7 +304,7 @@ async def _run(request: Request) -> Response:
 
             try:
                 ret = await runner_method.async_run(*params.args, **params.kwargs)
-            except BaseException as exc:
+            except Exception as exc:  # pylint: disable=broad-except
                 logger.error(
                     f"Exception on runner '{runner_method.runner.name}' method '{runner_method.name}'",
                     exc_info=exc,
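Worth noting on this last hunk: KeyboardInterrupt, SystemExit, and asyncio.CancelledError derive from BaseException but not from Exception, so after this change they propagate out of the runner instead of being caught and logged as inference failures.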
