diff --git a/src/bentoml/_internal/io_descriptors/base.py b/src/bentoml/_internal/io_descriptors/base.py index c64b8034009..f6c03e0f1e1 100644 --- a/src/bentoml/_internal/io_descriptors/base.py +++ b/src/bentoml/_internal/io_descriptors/base.py @@ -30,6 +30,7 @@ ) OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]] + IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {} IOType = t.TypeVar("IOType") diff --git a/src/bentoml/_internal/server/grpc/servicer.py b/src/bentoml/_internal/server/grpc/servicer.py deleted file mode 100644 index e5e7209b8f1..00000000000 --- a/src/bentoml/_internal/server/grpc/servicer.py +++ /dev/null @@ -1,269 +0,0 @@ -from __future__ import annotations - -import sys -import typing as t -import asyncio -import logging -from typing import TYPE_CHECKING -from inspect import isawaitable - -import anyio - -from ...utils import LazyLoader -from ....exceptions import InvalidArgument -from ....exceptions import BentoMLException -from ....grpc.utils import import_grpc -from ....grpc.utils import grpc_status_code -from ....grpc.utils import validate_proto_fields -from ....grpc.utils import import_generated_stubs -from ....grpc.utils import LATEST_PROTOCOL_VERSION - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) - - import grpc - from grpc import aio - from grpc_health.v1 import health - from google.protobuf import struct_pb2 - from typing_extensions import Self - - from bentoml.grpc.types import Interceptors - from bentoml.grpc.types import AddServicerFn - from bentoml.grpc.types import ServicerClass - from bentoml.grpc.types import BentoServicerContext - - from ...service.service import Service - - if LATEST_PROTOCOL_VERSION == "v1": - from bentoml.grpc.v1 import service_pb2 as pb - from bentoml.grpc.v1 import service_pb2_grpc as services - from bentoml.grpc.v1.service_pb2 import ServiceMetadataRequest - from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse - else: - from bentoml.grpc.v1alpha1 import service_pb2 as pb - from bentoml.grpc.v1alpha1 import service_pb2_grpc as services - -else: - grpc, aio = import_grpc() - health = LazyLoader( - "health", - globals(), - "grpc_health.v1.health", - exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", - ) - struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") - - -def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: - # gRPC will always send a POST request. - logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) - - -class Servicer: - """Create an instance of gRPC Servicer.""" - - def __init__( - self: Self, - service: Service, - on_startup: t.Sequence[t.Callable[[], t.Any]] | None = None, - on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None, - mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]] - | None = None, - interceptors: Interceptors | None = None, - protocol_version: str = LATEST_PROTOCOL_VERSION, - ) -> None: - self.bento_service = service - - self.on_startup = [] if not on_startup else list(on_startup) - self.on_shutdown = [] if not on_shutdown else list(on_shutdown) - self.mount_servicers = [] if not mount_servicers else list(mount_servicers) - self.interceptors = [] if not interceptors else list(interceptors) - self.loaded = False - self.protocol_version = protocol_version - - def load(self): - pb, _ = import_generated_stubs(self.protocol_version) - assert not self.loaded - - self.interceptors_stack = self.build_interceptors_stack() - - self.bento_servicer = create_bento_servicer( - self.bento_service, protocol_version=self.protocol_version - ) - - # Create a health check servicer. We use the non-blocking implementation - # to avoid thread starvation. - self.health_servicer = health.aio.HealthServicer() - - self.service_names = tuple( - service.full_name for service in pb.DESCRIPTOR.services_by_name.values() - ) + (health.SERVICE_NAME,) - self.loaded = True - - def build_interceptors_stack(self) -> list[aio.ServerInterceptor]: - return list(map(lambda x: x(), self.interceptors)) - - async def startup(self): - for handler in self.on_startup: - out = handler() - if isawaitable(out): - await out - - async def shutdown(self): - for handler in self.on_shutdown: - out = handler() - if isawaitable(out): - await out - - def __bool__(self): - return self.loaded - - -def create_bento_servicer( - service: Service, protocol_version: str = LATEST_PROTOCOL_VERSION -) -> services.BentoServiceServicer: - """ - This is the actual implementation of BentoServicer. - Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call - """ - if protocol_version == "v1": - from bentoml.grpc.v1 import service_pb2 as pb - from bentoml.grpc.v1.service_pb2_grpc import BentoServiceServicer - else: - from bentoml.grpc.v1alpha1 import service_pb2 as pb - from bentoml.grpc.v1alpha1.service_pb2_grpc import BentoServiceServicer - - attrs: dict[str, t.Any] = { - "__doc__": "An asyncio implementation of BentoService servicer." - } - - async def Call( - _: BentoServiceServicer, - request: pb.Request, - context: BentoServicerContext, - ) -> pb.Response | None: - if request.api_name not in service.apis: - raise InvalidArgument( - f"given 'api_name' is not defined in {service.name}", - ) from None - - api = service.apis[request.api_name] - response = pb.Response() - - # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved. - # This is important so that we know the order of fields to process. - # We will use fields descriptor to determine how to process that request. - try: - # we will check if the given fields list contains a pb.Multipart. - input_proto = getattr( - request, - validate_proto_fields(request.WhichOneof("content"), api.input), - ) - input_data = await api.input.from_proto(input_proto) - if asyncio.iscoroutinefunction(api.func): - if api.multi_input: - output = await api.func(**input_data) - else: - output = await api.func(input_data) - else: - if api.multi_input: - output = await anyio.to_thread.run_sync(api.func, **input_data) - else: - output = await anyio.to_thread.run_sync(api.func, input_data) - res = await api.output.to_proto(output) - # TODO(aarnphm): support multiple proto fields - response = pb.Response(**{api.output._proto_fields[0]: res}) - except BentoMLException as e: - log_exception(request, sys.exc_info()) - await context.abort(code=grpc_status_code(e), details=e.message) - except (RuntimeError, TypeError, NotImplementedError): - log_exception(request, sys.exc_info()) - await context.abort( - code=grpc.StatusCode.INTERNAL, - details="A runtime error has occurred, see stacktrace from logs.", - ) - except Exception: # pylint: disable=broad-except - log_exception(request, sys.exc_info()) - await context.abort( - code=grpc.StatusCode.INTERNAL, - details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", - ) - return response - - attrs.setdefault("Call", Call) - - if protocol_version == "v1": - # "v1" introduces ServiceMetadata to send in bentoml.Service information. - from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse - - async def ServiceMetadata( - _: BentoServiceServicer, - request: ServiceMetadataRequest, # pylint: disable=unused-argument - context: BentoServicerContext, # pylint: disable=unused-argument - ) -> ServiceMetadataResponse: - return ServiceMetadataResponse( - name=service.name, - docs=service.doc, - apis=[ - ServiceMetadataResponse.InferenceAPI( - name=api.name, - docs=api.doc, - input=make_descriptor_spec( - api.input.to_spec(), ServiceMetadataResponse - ), - output=make_descriptor_spec( - api.output.to_spec(), ServiceMetadataResponse - ), - ) - for api in service.apis.values() - ], - ) - - attrs.setdefault("ServiceMetadata", ServiceMetadata) - - if TYPE_CHECKING: - # NOTE: typeshed only accept type expression for type() class creation. - # Hence, pyright will raise an error if we only pass in BentoServiceServicer, as it won't - # acknowledge BentoServiceServicer as a type expression. - BentoServiceServicerT = type(BentoServiceServicer) - else: - BentoServiceServicerT = BentoServiceServicer - - return type("BentoServiceImpl", (BentoServiceServicerT,), attrs)() - - -if TYPE_CHECKING: - NestedDictStrAny = dict[str, dict[str, t.Any] | t.Any] - TupleAny = tuple[t.Any, ...] - - -def _tuple_converter(d: NestedDictStrAny | None) -> NestedDictStrAny | None: - # handles case for struct_pb2.Value where nested items are tuple. - # if that is the case, then convert to list. - # This dict is only one level deep, as we don't allow nested Multipart. - if d is not None: - for key, value in d.items(): - if isinstance(value, tuple): - d[key] = list(t.cast("TupleAny", value)) - elif isinstance(value, dict): - d[key] = _tuple_converter(t.cast("NestedDictStrAny", value)) - return d - - -def make_descriptor_spec( - spec: dict[str, t.Any], pb: type[ServiceMetadataResponse] -) -> ServiceMetadataResponse.DescriptorMetadata: - from ...io_descriptors.json import parse_dict_to_proto - - return pb.DescriptorMetadata( - descriptor_id=spec["id"], - attributes=struct_pb2.Struct( - fields={ - "args": parse_dict_to_proto( - _tuple_converter(spec.get("args", None)), struct_pb2.Value() - ) - } - ), - ) diff --git a/src/bentoml/_internal/server/grpc/servicer/__init__.py b/src/bentoml/_internal/server/grpc/servicer/__init__.py new file mode 100644 index 00000000000..058d1ca7062 --- /dev/null +++ b/src/bentoml/_internal/server/grpc/servicer/__init__.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import typing as t +import logging +import importlib +from typing import TYPE_CHECKING +from inspect import isawaitable + +from ....utils import LazyLoader +from .....grpc.utils import import_generated_stubs +from .....grpc.utils import LATEST_PROTOCOL_VERSION + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from types import ModuleType + + from grpc import aio + from grpc_health.v1 import health + from typing_extensions import Self + + from bentoml.grpc.v1 import service_pb2_grpc as services + from bentoml.grpc.types import Interceptors + from bentoml.grpc.types import AddServicerFn + from bentoml.grpc.types import ServicerClass + + from ....service.service import Service + + class ServicerModule(ModuleType): + @staticmethod + def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + ... + +else: + health = LazyLoader( + "health", + globals(), + "grpc_health.v1.health", + exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", + ) + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") + + +class Servicer: + """Create an instance of gRPC Servicer.""" + + _cached_module = None + + def __init__( + self: Self, + service: Service, + on_startup: t.Sequence[t.Callable[[], t.Any]] | None = None, + on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None, + mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]] + | None = None, + interceptors: Interceptors | None = None, + protocol_version: str = LATEST_PROTOCOL_VERSION, + ) -> None: + self.bento_service = service + + self.on_startup = [] if not on_startup else list(on_startup) + self.on_shutdown = [] if not on_shutdown else list(on_shutdown) + self.mount_servicers = [] if not mount_servicers else list(mount_servicers) + self.interceptors = [] if not interceptors else list(interceptors) + self.protocol_version = protocol_version + + self.loaded = False + + @property + def _servicer_module(self) -> ServicerModule: + if self._cached_module is None: + object.__setattr__( + self, + "_cached_module", + importlib.import_module(f".{self.protocol_version}", package=__name__), + ) + assert self._cached_module is not None + return self._cached_module + + def load(self): + assert not self.loaded + + pb, _ = import_generated_stubs(self.protocol_version) + + self.interceptors_stack = self.build_interceptors_stack() + + self.bento_servicer = self._servicer_module.create_bento_servicer( + self.bento_service + ) + + # Create a health check servicer. We use the non-blocking implementation + # to avoid thread starvation. + self.health_servicer = health.aio.HealthServicer() + + self.service_names = tuple( + service.full_name for service in pb.DESCRIPTOR.services_by_name.values() + ) + (health.SERVICE_NAME,) + self.loaded = True + + def build_interceptors_stack(self) -> list[aio.ServerInterceptor]: + return list(map(lambda x: x(), self.interceptors)) + + async def startup(self): + for handler in self.on_startup: + out = handler() + if isawaitable(out): + await out + + async def shutdown(self): + for handler in self.on_shutdown: + out = handler() + if isawaitable(out): + await out + + def __bool__(self): + return self.loaded diff --git a/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py new file mode 100644 index 00000000000..f453d09a62a --- /dev/null +++ b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import sys +import typing as t +import asyncio +import logging +from typing import TYPE_CHECKING + +import anyio + +from .....utils import LazyLoader +from ......exceptions import InvalidArgument +from ......exceptions import BentoMLException +from ......grpc.utils import import_grpc +from ......grpc.utils import grpc_status_code +from ......grpc.utils import validate_proto_fields +from ......grpc.utils import import_generated_stubs + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) + + import grpc + from google.protobuf import struct_pb2 + + from bentoml.grpc.types import BentoServicerContext + from bentoml.grpc.v1.service_pb2 import ServiceMetadataRequest + from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse + + from ......grpc.v1 import service_pb2 as pb + from ......grpc.v1 import service_pb2_grpc as services + from .....service.service import Service +else: + grpc, _ = import_grpc() + pb, services = import_generated_stubs(version="v1") + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") + + +def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: + # gRPC will always send a POST request. + logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) + + +def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + """ + This is the actual implementation of BentoServicer. + Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call + """ + + class BentoServiceImpl(services.BentoServiceServicer): + """An asyncio implementation of BentoService servicer.""" + + async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self: services.BentoServiceServicer, + request: pb.Request, + context: BentoServicerContext, + ) -> pb.Response | None: + if request.api_name not in service.apis: + raise InvalidArgument( + f"given 'api_name' is not defined in {service.name}", + ) from None + + api = service.apis[request.api_name] + response = pb.Response() + + # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved. + # This is important so that we know the order of fields to process. + # We will use fields descriptor to determine how to process that request. + try: + # we will check if the given fields list contains a pb.Multipart. + input_proto = getattr( + request, + validate_proto_fields(request.WhichOneof("content"), api.input), + ) + input_data = await api.input.from_proto(input_proto) + if asyncio.iscoroutinefunction(api.func): + if api.multi_input: + output = await api.func(**input_data) + else: + output = await api.func(input_data) + else: + if api.multi_input: + output = await anyio.to_thread.run_sync(api.func, **input_data) + else: + output = await anyio.to_thread.run_sync(api.func, input_data) + res = await api.output.to_proto(output) + # TODO(aarnphm): support multiple proto fields + response = pb.Response(**{api.output._proto_fields[0]: res}) + except BentoMLException as e: + log_exception(request, sys.exc_info()) + await context.abort(code=grpc_status_code(e), details=e.message) + except (RuntimeError, TypeError, NotImplementedError): + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="A runtime error has occurred, see stacktrace from logs.", + ) + except Exception: # pylint: disable=broad-except + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", + ) + return response + + async def ServiceMetadata( # type: ignore + self: services.BentoServiceServicer, + request: ServiceMetadataRequest, # pylint: disable=unused-argument + context: BentoServicerContext, # pylint: disable=unused-argument + ) -> ServiceMetadataResponse: + return pb.ServiceMetadataResponse( + name=service.name, + docs=service.doc, + apis=[ + pb.ServiceMetadataResponse.InferenceAPI( + name=api.name, + docs=api.doc, + input=make_descriptor_spec( + api.input.to_spec(), pb.ServiceMetadataResponse + ), + output=make_descriptor_spec( + api.output.to_spec(), pb.ServiceMetadataResponse + ), + ) + for api in service.apis.values() + ], + ) + + return BentoServiceImpl() + + +if TYPE_CHECKING: + NestedDictStrAny = dict[str, dict[str, t.Any] | t.Any] + TupleAny = tuple[t.Any, ...] + + +def _tuple_converter(d: NestedDictStrAny | None) -> NestedDictStrAny | None: + # handles case for struct_pb2.Value where nested items are tuple. + # if that is the case, then convert to list. + # This dict is only one level deep, as we don't allow nested Multipart. + if d is not None: + for key, value in d.items(): + if isinstance(value, tuple): + d[key] = list(t.cast("TupleAny", value)) + elif isinstance(value, dict): + d[key] = _tuple_converter(t.cast("NestedDictStrAny", value)) + return d + + +def make_descriptor_spec( + spec: dict[str, t.Any], pb: type[ServiceMetadataResponse] +) -> ServiceMetadataResponse.DescriptorMetadata: + from .....io_descriptors.json import parse_dict_to_proto + + descriptor_id = spec.pop("id") + return pb.DescriptorMetadata( + descriptor_id=descriptor_id, + attributes=struct_pb2.Struct( + fields={ + key: parse_dict_to_proto(_tuple_converter(value), struct_pb2.Value()) + for key, value in spec.items() + } + ), + ) diff --git a/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py b/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py new file mode 100644 index 00000000000..33688f14c93 --- /dev/null +++ b/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import sys +import asyncio +import logging +from typing import TYPE_CHECKING + +import anyio + +from ......exceptions import InvalidArgument +from ......exceptions import BentoMLException +from ......grpc.utils import import_grpc +from ......grpc.utils import grpc_status_code +from ......grpc.utils import validate_proto_fields +from ......grpc.utils import import_generated_stubs + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) + + import grpc + + from bentoml.grpc.types import BentoServicerContext + + from ......grpc.v1alpha1 import service_pb2 as pb + from ......grpc.v1alpha1 import service_pb2_grpc as services + from .....service.service import Service +else: + grpc, _ = import_grpc() + pb, services = import_generated_stubs(version="v1alpha1") + + +def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: + # gRPC will always send a POST request. + logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) + + +def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + """ + This is the actual implementation of BentoServicer. + Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call + """ + + class BentoServiceImpl(services.BentoServiceServicer): + """An asyncio implementation of BentoService servicer.""" + + async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self: services.BentoServiceServicer, + request: pb.Request, + context: BentoServicerContext, + ) -> pb.Response | None: + if request.api_name not in service.apis: + raise InvalidArgument( + f"given 'api_name' is not defined in {service.name}", + ) from None + + api = service.apis[request.api_name] + response = pb.Response() + + # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved. + # This is important so that we know the order of fields to process. + # We will use fields descriptor to determine how to process that request. + try: + # we will check if the given fields list contains a pb.Multipart. + input_proto = getattr( + request, + validate_proto_fields(request.WhichOneof("content"), api.input), + ) + input_data = await api.input.from_proto(input_proto) + if asyncio.iscoroutinefunction(api.func): + if api.multi_input: + output = await api.func(**input_data) + else: + output = await api.func(input_data) + else: + if api.multi_input: + output = await anyio.to_thread.run_sync(api.func, **input_data) + else: + output = await anyio.to_thread.run_sync(api.func, input_data) + res = await api.output.to_proto(output) + # TODO(aarnphm): support multiple proto fields + response = pb.Response(**{api.output._proto_fields[0]: res}) + except BentoMLException as e: + log_exception(request, sys.exc_info()) + await context.abort(code=grpc_status_code(e), details=e.message) + except (RuntimeError, TypeError, NotImplementedError): + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="A runtime error has occurred, see stacktrace from logs.", + ) + except Exception: # pylint: disable=broad-except + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", + ) + return response + + return BentoServiceImpl() diff --git a/src/bentoml/_internal/service/inference_api.py b/src/bentoml/_internal/service/inference_api.py index bffd6126965..9fb9d98b609 100644 --- a/src/bentoml/_internal/service/inference_api.py +++ b/src/bentoml/_internal/service/inference_api.py @@ -26,7 +26,7 @@ class InferenceAPI: def __init__( self, - user_defined_callback: t.Callable[..., t.Any], + user_defined_callback: t.Callable[..., t.Any] | None, input_descriptor: IODescriptor[t.Any], output_descriptor: IODescriptor[t.Any], name: Optional[str],