diff --git a/bentoml/_internal/marshal/dispatcher.py b/bentoml/_internal/marshal/dispatcher.py
index 25e8ba5e70..3a3b7dca5e 100644
--- a/bentoml/_internal/marshal/dispatcher.py
+++ b/bentoml/_internal/marshal/dispatcher.py
@@ -149,7 +149,7 @@ def _wake_event(self):
     def __call__(
         self,
         callback: t.Callable[
-            [t.Collection[T_IN]], t.Coroutine[None, None, t.Collection[T_OUT]]
+            [t.Sequence[T_IN]], t.Coroutine[None, None, t.Sequence[T_OUT]]
         ],
     ) -> t.Callable[[T_IN], t.Coroutine[None, None, T_OUT]]:
         self.callback = callback
diff --git a/bentoml/_internal/server/runner_app.py b/bentoml/_internal/server/runner_app.py
index f463590ba6..2f3e94c456 100644
--- a/bentoml/_internal/server/runner_app.py
+++ b/bentoml/_internal/server/runner_app.py
@@ -11,6 +11,8 @@
 from simple_di import inject
 from simple_di import Provide
 
+from bentoml._internal.types import LazyType
+
 from ..context import trace_context
 from ..context import component_context
 from ..runner.utils import Params
@@ -19,6 +21,7 @@
 from ..utils.metrics import metric_name
 from ..utils.metrics import exponential_buckets
 from ..server.base_app import BaseAppFactory
+from ..runner.container import Payload
 from ..runner.container import AutoContainer
 from ..marshal.dispatcher import CorkDispatcher
 from ..configuration.containers import BentoMLContainer
@@ -138,13 +141,10 @@ def routes(self) -> t.List[BaseRoute]:
         for method in self.runner.runner_methods:
             path = "/" if method.name == "__call__" else "/" + method.name
             if method.config.batchable:
-                _func = self.dispatchers[method.name](
-                    self._async_cork_run(runner_method=method)
-                )
                 routes.append(
                     Route(
                         path=path,
-                        endpoint=_func,
+                        endpoint=self._mk_request_handler(runner_method=method),
                         methods=["POST"],
                     )
                 )
@@ -205,32 +205,30 @@ def client_request_hook(span: Span, _scope: t.Dict[str, t.Any]) -> None:
 
         return middlewares
 
-    def _async_cork_run(
+    def _mk_request_handler(
         self,
         runner_method: RunnerMethod[t.Any, t.Any, t.Any],
-    ) -> t.Callable[[t.Collection[Request]], t.Coroutine[None, None, list[Response]]]:
+    ) -> t.Callable[[Request], t.Coroutine[None, None, Response]]:
         from starlette.responses import Response
 
-        async def _run(requests: t.Collection[Request]) -> list[Response]:
-            assert self._is_ready
+        server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"
 
-            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(
-                len(requests)
+        async def infer_batch(
+            params_list: t.Sequence[Params[t.Any]],
+        ) -> list[Payload] | list[tuple[Payload, ...]]:
+            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(  # type: ignore
+                len(params_list)
             )
-            self.adaptive_batch_size_hist.labels(
+            self.adaptive_batch_size_hist.labels(  # type: ignore
                 runner_name=self.runner.name,
                 worker_index=self.worker_index,
                 method_name=runner_method.name,
                 service_version=component_context.bento_version,
                 service_name=component_context.bento_name,
-            ).observe(len(requests))
+            ).observe(len(params_list))
 
-            if not requests:
+            if not params_list:
                 return []
 
-            params_list: list[Params[t.Any]] = []
-            for r in requests:
-                r_ = await r.body()
-                params_list.append(pickle.loads(r_))
 
             input_batch_dim, output_batch_dim = runner_method.config.batch_dim
@@ -242,29 +240,17 @@ async def _run(requests: t.Collection[Request]) -> list[Response]:
                 *batched_params.args, **batched_params.kwargs
             )
 
-            server_str = f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}"
-
             # multiple output branch
-            if isinstance(batch_ret, tuple):
+            if LazyType["tuple[t.Any, ...]"](tuple).isinstance(batch_ret):
                 output_num = len(batch_ret)
-                payloadss = [
+                payloadss = tuple(
                     AutoContainer.batch_to_payloads(
                         batch_ret[idx], indices, batch_dim=output_batch_dim
                     )
                     for idx in range(output_num)
-                ]
-
-                return [
-                    Response(
-                        pickle.dumps(payloads),
-                        headers={
-                            PAYLOAD_META_HEADER: json.dumps({}),
-                            "Content-Type": "application/vnd.bentoml.multiple_outputs",
-                            "Server": server_str,
-                        },
-                    )
-                    for payloads in zip(*payloadss)
-                ]
+                )
+                ret = list(zip(*payloadss))
+                return ret
 
             # single output branch
             payloads = AutoContainer.batch_to_payloads(
@@ -272,20 +258,39 @@ async def _run(requests: t.Collection[Request]) -> list[Response]:
                 indices,
                 batch_dim=output_batch_dim,
             )
+            return payloads
 
-            return [
-                Response(
-                    payload.data,
+        infer = self.dispatchers[runner_method.name](infer_batch)
+
+        async def _request_handler(request: Request) -> Response:
+            assert self._is_ready
+
+            r_: bytes = await request.body()
+            params: Params[t.Any] = pickle.loads(r_)
+
+            payload = await infer(params)
+
+            if not isinstance(
+                payload, Payload
+            ):  # a tuple, which means user runnable has multiple outputs
+                return Response(
+                    pickle.dumps(payload),
                     headers={
-                        PAYLOAD_META_HEADER: json.dumps(payload.meta),
-                        "Content-Type": f"application/vnd.bentoml.{payload.container}",
+                        PAYLOAD_META_HEADER: json.dumps({}),
+                        "Content-Type": "application/vnd.bentoml.multiple_outputs",
                         "Server": server_str,
                     },
                 )
-                for payload in payloads
-            ]
+            return Response(
+                payload.data,
+                headers={
+                    PAYLOAD_META_HEADER: json.dumps(payload.meta),
+                    "Content-Type": f"application/vnd.bentoml.{payload.container}",
+                    "Server": server_str,
+                },
+            )
 
-        return _run
+        return _request_handler
 
     def async_run(
         self,
@@ -302,7 +307,7 @@ async def _run(request: Request) -> Response:
 
             try:
                 ret = await runner_method.async_run(*params.args, **params.kwargs)
-            except BaseException as exc:
+            except Exception as exc:  # pylint: disable=broad-except
                 logger.error(
                     f"Exception on runner '{runner_method.runner.name}' method '{runner_method.name}'",
                     exc_info=exc,