diff --git a/bentoml/_internal/marshal/dispatcher.py b/bentoml/_internal/marshal/dispatcher.py
index 25e8ba5e70..3a3b7dca5e 100644
--- a/bentoml/_internal/marshal/dispatcher.py
+++ b/bentoml/_internal/marshal/dispatcher.py
@@ -149,7 +149,7 @@ def _wake_event(self):
     def __call__(
         self,
         callback: t.Callable[
-            [t.Collection[T_IN]], t.Coroutine[None, None, t.Collection[T_OUT]]
+            [t.Sequence[T_IN]], t.Coroutine[None, None, t.Sequence[T_OUT]]
         ],
     ) -> t.Callable[[T_IN], t.Coroutine[None, None, T_OUT]]:
         self.callback = callback
diff --git a/bentoml/_internal/server/runner_app.py b/bentoml/_internal/server/runner_app.py
index e7fa1f29f0..b3f86fc81e 100644
--- a/bentoml/_internal/server/runner_app.py
+++ b/bentoml/_internal/server/runner_app.py
@@ -35,6 +35,7 @@
     from ..runner.runner import Runner
     from ..runner.runner import RunnerMethod
+    from ..runner.container import Payload
 
 
 class RunnerAppFactory(BaseAppFactory):
@@ -138,13 +139,10 @@ def routes(self) -> t.List[BaseRoute]:
         for method in self.runner.runner_methods:
             path = "/" if method.name == "__call__" else "/" + method.name
             if method.config.batchable:
-                _func = self.dispatchers[method.name](
-                    self._async_cork_run(runner_method=method)
-                )
                 routes.append(
                     Route(
                         path=path,
-                        endpoint=_func,
+                        endpoint=self._mk_request_handler(runner_method=method),
                         methods=["POST"],
                     )
                 )
@@ -205,32 +203,26 @@ def client_request_hook(span: Span, _scope: t.Dict[str, t.Any]) -> None:
 
         return middlewares
 
-    def _async_cork_run(
+    def _mk_request_handler(
         self,
         runner_method: RunnerMethod[t.Any, t.Any, t.Any],
-    ) -> t.Callable[[t.Collection[Request]], t.Coroutine[None, None, list[Response]]]:
+    ) -> t.Callable[[Request], t.Coroutine[None, None, Response]]:
         from starlette.responses import Response
 
-        async def _run(requests: t.Collection[Request]) -> list[Response]:
-            assert self._is_ready
-
-            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(
-                len(requests)
+        async def infer_batch(params_list: t.Sequence[Params[t.Any]]) -> list[Payload]:
+            self.legacy_adaptive_batch_size_hist_map[runner_method.name].observe(  # type: ignore
+                len(params_list)
             )
-            self.adaptive_batch_size_hist.labels(
+            self.adaptive_batch_size_hist.labels(  # type: ignore
                 runner_name=self.runner.name,
                 worker_index=self.worker_index,
                 method_name=runner_method.name,
                 service_version=component_context.bento_version,
                 service_name=component_context.bento_name,
-            ).observe(len(requests))
+            ).observe(len(params_list))
 
-            if not requests:
+            if not params_list:
                 return []
 
-            params_list: list[Params[t.Any]] = []
-            for r in requests:
-                r_ = await r.body()
-                params_list.append(pickle.loads(r_))
 
             input_batch_dim, output_batch_dim = runner_method.config.batch_dim
@@ -247,20 +239,28 @@ async def _run(requests: t.Collection[Request]) -> list[Response]:
                 indices,
                 batch_dim=output_batch_dim,
             )
+            return payloads
 
-            return [
-                Response(
-                    payload.data,
-                    headers={
-                        PAYLOAD_META_HEADER: json.dumps(payload.meta),
-                        "Content-Type": f"application/vnd.bentoml.{payload.container}",
-                        "Server": f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}",
-                    },
-                )
-                for payload in payloads
-            ]
+        infer = self.dispatchers[runner_method.name](infer_batch)
 
-        return _run
+        async def _request_handler(request: Request) -> Response:
+            assert self._is_ready
+
+            r_: bytes = await request.body()
+            params: Params[t.Any] = pickle.loads(r_)
+
+            payload = await infer(params)
+
+            return Response(
+                payload.data,
+                headers={
+                    PAYLOAD_META_HEADER: json.dumps(payload.meta),
+                    "Content-Type": f"application/vnd.bentoml.{payload.container}",
+                    "Server": f"BentoML-Runner/{self.runner.name}/{runner_method.name}/{self.worker_index}",
+                },
+            )
+
+        return _request_handler
 
     def async_run(
         self,
@@ -277,7 +277,7 @@ async def _run(request: Request) -> Response:
 
             try:
                 ret = await runner_method.async_run(*params.args, **params.kwargs)
-            except BaseException as exc:
+            except Exception as exc:  # pylint: disable=broad-except
                 logger.error(
                     f"Exception on runner '{runner_method.runner.name}' method '{runner_method.name}'",
                     exc_info=exc,