Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: grpc servicer implementation per version #3316

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/bentoml/_internal/bento/build_dev_bentoml_whl.py
Expand Up @@ -7,6 +7,7 @@
from ..utils.pkg import source_locations
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ...grpc.utils import LATEST_PROTOCOL_VERSION
from ..configuration import is_pypi_installed_bentoml

logger = logging.getLogger(__name__)
Expand All @@ -15,7 +16,7 @@


def build_bentoml_editable_wheel(
target_path: str, *, _internal_stubs_version: str = "v1"
target_path: str, *, _internal_protocol_version: str = LATEST_PROTOCOL_VERSION
) -> None:
"""
This is for BentoML developers to create Bentos that contain the local bentoml
Expand Down Expand Up @@ -52,10 +53,10 @@ def build_bentoml_editable_wheel(
bentoml_path = Path(module_location)

if not Path(
module_location, "grpc", _internal_stubs_version, "service_pb2.py"
module_location, "grpc", _internal_protocol_version, "service_pb2.py"
).exists():
raise ModuleNotFoundError(
f"Generated stubs for version {_internal_stubs_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_stubs_version}' beforehand to generate gRPC stubs."
f"Generated stubs for version {_internal_protocol_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_protocol_version}' beforehand to generate gRPC stubs."
) from None

# location to pyproject.toml
Expand Down
2 changes: 1 addition & 1 deletion src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -36,7 +36,7 @@
IOType = t.TypeVar("IOType")


def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]:
def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]:
    """
    Rebuild an IO descriptor from its spec.

    The entry stored in ``IO_DESCRIPTOR_REGISTRY`` under ``spec["id"]`` is
    looked up and its own ``from_spec`` is called with the whole spec.

    Raises:
        InvalidArgument: if ``spec`` does not contain an ``"id"`` key.
    """
    if "id" in spec:
        registered = IO_DESCRIPTOR_REGISTRY[spec["id"]]
        return registered.from_spec(spec)
    raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.")
Expand Down
41 changes: 26 additions & 15 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -30,6 +30,7 @@

import pydantic
import pydantic.schema as schema
from google.protobuf import message as _message
from google.protobuf import struct_pb2
from typing_extensions import Self

Expand Down Expand Up @@ -392,19 +393,29 @@ async def to_proto(self, obj: JSONType) -> struct_pb2.Value:
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj):
obj = obj.dict()
msg = struct_pb2.Value()
# To handle None cases.
if obj is not None:
from google.protobuf.json_format import ParseDict

if isinstance(obj, (dict, str, list, float, int, bool)):
# ParseDict handles google.protobuf.Struct type
# directly if given object has a supported type
ParseDict(obj, msg)
else:
# If given object doesn't have a supported type, we will
# use given JSON encoder to convert it to dictionary
# and then parse it to google.protobuf.Struct.
# Note that if a custom JSON encoder is used, it mustn't
# take any arguments.
ParseDict(self._json_encoder().default(obj), msg)
return parse_dict_to_proto(obj, msg, json_encoder=self._json_encoder)


def parse_dict_to_proto(
    obj: JSONType,
    msg: _message.Message,
    json_encoder: type[json.JSONEncoder] = DefaultJsonEncoder,
) -> t.Any:
    """
    Populate ``msg`` from ``obj`` with ``ParseDict`` and return ``msg``.

    Args:
        obj: The value to parse. ``None`` leaves ``msg`` untouched.
        msg: The protobuf message that receives the parsed content.
        json_encoder: Encoder class used to convert ``obj`` when it is not a
            dict, str, list, float, int or bool. It is instantiated without
            arguments, so a custom encoder mustn't require any.

    Returns:
        The same ``msg`` object that was passed in.
    """
    if obj is None:
        # Identity op: nothing to parse, hand the message back unchanged.
        return msg

    from google.protobuf.json_format import ParseDict

    # Objects of these types are handed to ParseDict as-is; anything else is
    # first converted through the JSON encoder's ``default`` hook.
    directly_parsable = isinstance(obj, (dict, str, list, float, int, bool))
    payload = obj if directly_parsable else json_encoder().default(obj)
    ParseDict(payload, msg)
    return msg
3 changes: 1 addition & 2 deletions src/bentoml/_internal/server/grpc/__init__.py
@@ -1,4 +1,3 @@
from .server import Server
from .servicer import Servicer

__all__ = ["Server", "Servicer"]
__all__ = ["Server"]
22 changes: 14 additions & 8 deletions src/bentoml/_internal/server/grpc/server.py
Expand Up @@ -4,6 +4,7 @@
import sys
import typing as t
import asyncio
import inspect
import logging
from typing import TYPE_CHECKING
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -29,7 +30,8 @@

from bentoml.grpc.v1 import service_pb2_grpc as services

from .servicer import Servicer
from ..grpc_app import GrpcServicerFactory

else:
grpc, aio = import_grpc()
_, services = import_generated_stubs()
Expand Down Expand Up @@ -61,7 +63,7 @@ class Server(aio._server.Server):
@inject
def __init__(
self,
servicer: Servicer,
servicer: GrpcServicerFactory,
bind_address: str,
max_message_length: int
| None = Provide[BentoMLContainer.grpc.max_message_length],
Expand All @@ -88,10 +90,6 @@ def __init__(
self.ssl_keyfile = ssl_keyfile
self.ssl_ca_certs = ssl_ca_certs

if not bool(self.servicer):
self.servicer.load()
assert self.servicer.loaded

super().__init__(
# Note that max_workers is used inside ThreadPoolExecutor.
# This ThreadPoolExecutor is used by aio.Server() to execute non-AsyncIO RPC handlers.
Expand Down Expand Up @@ -189,7 +187,11 @@ async def startup(self) -> None:
from bentoml.exceptions import MissingDependencyException

# Running on_startup callback.
await self.servicer.startup()
for handler in self.servicer.on_startup:
out = handler()
if inspect.isawaitable(out):
await out

# register bento servicer
services.add_BentoServiceServicer_to_server(self.servicer.bento_servicer, self)
services_health.add_HealthServicer_to_server(
Expand Down Expand Up @@ -236,7 +238,11 @@ async def startup(self) -> None:

async def shutdown(self):
# Running on_shutdown callback.
await self.servicer.shutdown()
for handler in self.servicer.on_shutdown:
out = handler()
if inspect.isawaitable(out):
await out

await self.stop(grace=self.graceful_shutdown_timeout)
await self.servicer.health_servicer.enter_graceful_shutdown()
self.loop.stop()
Empty file.
100 changes: 100 additions & 0 deletions src/bentoml/_internal/server/grpc/servicer/v1/__init__.py
@@ -0,0 +1,100 @@
from __future__ import annotations

import sys
import asyncio
import logging
from typing import TYPE_CHECKING

import anyio

from ......exceptions import InvalidArgument
from ......exceptions import BentoMLException
from ......grpc.utils import import_grpc
from ......grpc.utils import grpc_status_code
from ......grpc.utils import validate_proto_fields
from ......grpc.utils import import_generated_stubs

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning)

import grpc

from ......grpc.v1 import service_pb2 as pb
from ......grpc.v1 import service_pb2_grpc as services
from ......grpc.types import BentoServicerContext
from .....service.service import Service
else:
grpc, _ = import_grpc()
pb, services = import_generated_stubs(version="v1")


def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None:
    """
    Log an error for a failed request, tagged with the request's ``api_name``.

    Args:
        request: The incoming request; only ``api_name`` is read.
        exc_info: Exception information forwarded to the logger.
    """
    api_name = request.api_name
    # gRPC will always send a POST request.
    logger.error("Exception on /%s [POST]", api_name, exc_info=exc_info)


def create_bento_servicer(service: Service) -> services.BentoServiceServicer:
    """
    This is the actual implementation of BentoServicer.
    Main inference entrypoint will be invoked via /bentoml.grpc.<version>.BentoService/Call

    Args:
        service: The service whose ``apis`` mapping resolves the ``api_name``
            carried by each incoming request.

    Returns:
        A new servicer instance whose ``Call`` coroutine closes over ``service``.
    """
    # Local import: keeps this block self-contained without touching the
    # module-level import section.
    from functools import partial

    class BentoServiceImpl(services.BentoServiceServicer):
        """An asyncio implementation of BentoService servicer."""

        async def Call(  # type: ignore (no async types) # pylint: disable=invalid-overridden-method
            self: services.BentoServiceServicer,
            request: pb.Request,
            context: BentoServicerContext,
        ) -> pb.Response | None:
            """
            Run the API named by ``request.api_name`` and wrap its output.

            The input is decoded with the API's input descriptor, the API
            function is awaited (coroutine functions) or run in a worker
            thread (plain functions), and the output is encoded with the
            API's output descriptor.

            Raises:
                InvalidArgument: if ``request.api_name`` is not in ``service.apis``.

            Any exception raised while decoding, running or encoding is logged
            and reported to the client through ``context.abort``.
            """
            if request.api_name not in service.apis:
                raise InvalidArgument(
                    f"given 'api_name' is not defined in {service.name}",
                ) from None

            api = service.apis[request.api_name]
            response = pb.Response()

            # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved.
            # This is important so that we know the order of fields to process.
            # We will use fields descriptor to determine how to process that request.
            try:
                # we will check if the given fields list contains a pb.Multipart.
                input_proto = getattr(
                    request,
                    validate_proto_fields(request.WhichOneof("content"), api.input),
                )
                input_data = await api.input.from_proto(input_proto)
                if asyncio.iscoroutinefunction(api.func):
                    if api.multi_input:
                        output = await api.func(**input_data)
                    else:
                        output = await api.func(input_data)
                else:
                    if api.multi_input:
                        # anyio.to_thread.run_sync() only forwards positional
                        # arguments to the callable; its keyword arguments are
                        # its own options. Passing **input_data directly raises
                        # TypeError, so bind the keyword arguments first.
                        output = await anyio.to_thread.run_sync(
                            partial(api.func, **input_data)
                        )
                    else:
                        output = await anyio.to_thread.run_sync(api.func, input_data)
                res = await api.output.to_proto(output)
                # TODO(aarnphm): support multiple proto fields
                response = pb.Response(**{api.output._proto_fields[0]: res})
            except BentoMLException as e:
                log_exception(request, sys.exc_info())
                await context.abort(code=grpc_status_code(e), details=e.message)
            except (RuntimeError, TypeError, NotImplementedError):
                log_exception(request, sys.exc_info())
                await context.abort(
                    code=grpc.StatusCode.INTERNAL,
                    details="A runtime error has occurred, see stacktrace from logs.",
                )
            except Exception:  # pylint: disable=broad-except
                log_exception(request, sys.exc_info())
                await context.abort(
                    code=grpc.StatusCode.INTERNAL,
                    details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.",
                )
            return response

    return BentoServiceImpl()