fix(runner): receive requests before cork #2996
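This PR moves HTTP request reception out of the batching path of BentoML's runner server: request bodies are now read and unpickled per request, before they enter the `CorkDispatcher`, so the dispatcher corks already-parsed `Params` objects and its batched callback returns one `Payload` per caller instead of building whole `Response` objects from raw `Request`s. A before/after sketch of the wiring (simplified, hypothetical names, not the actual BentoML classes):

```python
# Before/after sketch with simplified, hypothetical names ("dispatcher",
# "run_batch", "to_response" stand in for the real BentoML pieces).
import pickle
import typing as t


# Before: the dispatcher received raw HTTP requests; every body was awaited
# and unpickled inside the batched callback, after the cork had fired.
async def old_callback(requests: t.Sequence[t.Any]) -> list[t.Any]:
    params_list = [pickle.loads(await r.body()) for r in requests]
    return [to_response(out) for out in run_batch(params_list)]


# After: each handler parses its own request first; only parsed params are
# corked, and the callback is a pure params -> payloads function.
async def new_handler(
    request: t.Any, infer: t.Callable[[t.Any], t.Awaitable[t.Any]]
) -> t.Any:
    params = pickle.loads(await request.body())  # received before the cork
    payload = await infer(params)                # infer = dispatcher(new_callback)
    return to_response(payload)


def run_batch(params_list: t.Sequence[t.Any]) -> list[t.Any]:
    return list(params_list)  # placeholder for the batched model call


def to_response(payload: t.Any) -> t.Any:
    return payload  # placeholder for Response construction
```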

Merged · 2 commits · Sep 16, 2022

Changes from all commits
2 changes: 1 addition & 1 deletion bentoml/_internal/marshal/dispatcher.py
```diff
@@ -149,7 +149,7 @@ def _wake_event(self):
     def __call__(
         self,
         callback: t.Callable[
-            [t.Collection[T_IN]], t.Coroutine[None, None, t.Collection[T_OUT]]
+            [t.Sequence[T_IN]], t.Coroutine[None, None, t.Sequence[T_OUT]]
         ],
     ) -> t.Callable[[T_IN], t.Coroutine[None, None, T_OUT]]:
         self.callback = callback
```
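The tightening from `t.Collection` to `t.Sequence` matters because the cork dispatcher routes the i-th output back to the i-th queued caller, which requires a stable iteration order that `Collection` does not guarantee. A minimal sketch of that ordering contract (hypothetical code, not `CorkDispatcher`'s actual internals):

```python
# Minimal sketch of the ordering contract (hypothetical, not BentoML's code):
# the batched callback must accept and return ordered Sequences so the i-th
# result can be handed back to the i-th waiting caller.
from __future__ import annotations

import asyncio
import typing as t

T_IN = t.TypeVar("T_IN")
T_OUT = t.TypeVar("T_OUT")


async def drain_batch(
    callback: t.Callable[[t.Sequence[T_IN]], t.Awaitable[t.Sequence[T_OUT]]],
    queued: t.Sequence[tuple[T_IN, asyncio.Future[T_OUT]]],
) -> None:
    outputs = await callback([inp for inp, _ in queued])
    # positional zip: only valid because both sides preserve order
    for (_, fut), out in zip(queued, outputs):
        fut.set_result(out)
```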
91 changes: 48 additions & 43 deletions bentoml/_internal/server/runner_app.py
```diff
@@ -11,6 +11,8 @@
 from simple_di import inject
 from simple_di import Provide
 
+from bentoml._internal.types import LazyType
+
 from ..context import trace_context
 from ..context import component_context
 from ..runner.utils import Params
@@ -19,6 +21,7 @@
 from ..utils.metrics import metric_name
 from ..utils.metrics import exponential_buckets
 from ..server.base_app import BaseAppFactory
+from ..runner.container import Payload
 from ..runner.container import AutoContainer
 from ..marshal.dispatcher import CorkDispatcher
 from ..configuration.containers import BentoMLContainer
@@ -138,13 +141,10 @@ def routes(self) -> t.List[BaseRoute]:
         for method in self.runner.runner_methods:
             path = "/" if method.name == "__call__" else "/" + method.name
             if method.config.batchable:
-                _func = self.dispatchers[method.name](
-                    self._async_cork_run(runner_method=method)
-                )
                 routes.append(
                     Route(
                         path=path,
-                        endpoint=_func,
+                        endpoint=self._mk_request_handler(runner_method=method),
                         methods=["POST"],
                     )
                 )
@@ -205,32 +205,30 @@ def client_request_hook(span: Span, _scope: t.Dict[str, t.Any]) -> None:
 
         return middlewares
 
-    def _async_cork_run(
+    def _mk_request_handler(
         self,
         runner_method: RunnerMethod[t.Any, t.Any, t.Any],
-    ) -> t.Callable[[t.Collection[Request]], t.Coroutine[None, None, list[Response]]]:
+    ) -> t.Callable[[Request], t.Coroutine[None, None, Response]]:
         from starlette.responses import Response
 
-        async def _run(requests: t.Collection[Request]) -> list[Response]:
-            assert self._is_ready
-            server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"
-
-            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(
-                len(requests)
+        server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"
+
+        async def infer_batch(
+            params_list: t.Sequence[Params[t.Any]],
+        ) -> list[Payload] | list[tuple[Payload, ...]]:
+            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(  # type: ignore
+                len(params_list)
             )
-            self.adaptive_batch_size_hist.labels(
+            self.adaptive_batch_size_hist.labels(  # type: ignore
                 runner_name=self.runner.name,
                 worker_index=self.worker_index,
                 method_name=runner_method.name,
                 service_version=component_context.bento_version,
                 service_name=component_context.bento_name,
-            ).observe(len(requests))
+            ).observe(len(params_list))
 
-            if not requests:
+            if not params_list:
                 return []
-            params_list: list[Params[t.Any]] = []
-            for r in requests:
-                r_ = await r.body()
-                params_list.append(pickle.loads(r_))
 
             input_batch_dim, output_batch_dim = runner_method.config.batch_dim
 
@@ -242,50 +240,57 @@ async def _run(requests: t.Collection[Request]) -> list[Response]:
                 *batched_params.args, **batched_params.kwargs
             )
 
             # multiple output branch
-            if isinstance(batch_ret, tuple):
+            if LazyType["tuple[t.Any, ...]"](tuple).isinstance(batch_ret):
                 output_num = len(batch_ret)
-                payloadss = [
+                payloadss = tuple(
                     AutoContainer.batch_to_payloads(
                         batch_ret[idx], indices, batch_dim=output_batch_dim
                     )
                     for idx in range(output_num)
-                ]
-
-                return [
-                    Response(
-                        pickle.dumps(payloads),
-                        headers={
-                            PAYLOAD_META_HEADER: json.dumps({}),
-                            "Content-Type": "application/vnd.bentoml.multiple_outputs",
-                            "Server": server_str,
-                        },
-                    )
-                    for payloads in zip(*payloadss)
-                ]
+                )
+                ret = list(zip(*payloadss))
+                return ret
 
             # single output branch
             payloads = AutoContainer.batch_to_payloads(
                 batch_ret,
                 indices,
                 batch_dim=output_batch_dim,
             )
+            return payloads
 
-            return [
-                Response(
-                    payload.data,
-                    headers={
-                        PAYLOAD_META_HEADER: json.dumps(payload.meta),
-                        "Content-Type": f"application/vnd.bentoml.{payload.container}",
-                        "Server": server_str,
-                    },
-                )
-                for payload in payloads
-            ]
+        infer = self.dispatchers[runner_method.name](infer_batch)
+
+        async def _request_handler(request: Request) -> Response:
+            assert self._is_ready
+
+            r_: bytes = await request.body()
+            params: Params[t.Any] = pickle.loads(r_)
+
+            payload = await infer(params)
+
+            if not isinstance(
+                payload, Payload
+            ):  # a tuple, which means user runnable has multiple outputs
+                return Response(
+                    pickle.dumps(payload),
+                    headers={
+                        PAYLOAD_META_HEADER: json.dumps({}),
+                        "Content-Type": "application/vnd.bentoml.multiple_outputs",
+                        "Server": server_str,
+                    },
+                )
+            return Response(
+                payload.data,
+                headers={
+                    PAYLOAD_META_HEADER: json.dumps(payload.meta),
+                    "Content-Type": f"application/vnd.bentoml.{payload.container}",
+                    "Server": server_str,
+                },
+            )
 
-        return _run
+        return _request_handler
```
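One subtlety in `infer_batch`: for a runnable with multiple outputs, `batch_to_payloads` yields one payload list per output, so `list(zip(*payloadss))` transposes the nesting from [output][request] to [request][output], letting the dispatcher hand each caller a tuple holding its own payloads. A tiny worked example of that transpose (plain strings standing in for `Payload` objects):

```python
# Worked example of the zip(*...) transpose used for multi-output runnables.
# Plain strings stand in for Payload objects.

# Two outputs, three corked requests: payloadss[output_idx][request_idx]
payloadss = (
    ["a0", "a1", "a2"],  # payloads for output A, one per request
    ["b0", "b1", "b2"],  # payloads for output B, one per request
)

# Transpose to one tuple per request: ret[request_idx] == (A, B)
ret = list(zip(*payloadss))
assert ret == [("a0", "b0"), ("a1", "b1"), ("a2", "b2")]
```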

```diff
     def async_run(
         self,
@@ -302,7 +307,7 @@ async def _run(request: Request) -> Response:
 
         try:
             ret = await runner_method.async_run(*params.args, **params.kwargs)
-        except BaseException as exc:
+        except Exception as exc:  # pylint: disable=broad-except
             logger.error(
                 f"Exception on runner '{runner_method.runner.name}' method '{runner_method.name}'",
                 exc_info=exc,
             )
```
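The narrowed `except Exception` matters in an async server: `BaseException` also covers `asyncio.CancelledError` (a `BaseException` subclass since Python 3.8) and `KeyboardInterrupt`, so the old clause could log and swallow task cancellation instead of letting it propagate. A minimal standalone illustration (assuming Python 3.8+, not BentoML code):

```python
# Minimal illustration (Python 3.8+): CancelledError derives from
# BaseException, so `except Exception` still catches ordinary runner errors
# while letting cancellation unwind the task as intended.
import asyncio


async def worker() -> None:
    try:
        await asyncio.sleep(3600)
    except Exception:
        # ordinary errors would be logged here; CancelledError is NOT
        # caught by this clause, so cancellation propagates
        raise


async def main() -> None:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled cleanly")  # reached: Exception didn't swallow it


asyncio.run(main())
```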