Skip to content

Commit

Permalink
feat: grpc servicer implementation per version
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Dec 6, 2022
1 parent 4834a1f commit 7ddf489
Show file tree
Hide file tree
Showing 16 changed files with 307 additions and 163 deletions.
7 changes: 4 additions & 3 deletions src/bentoml/_internal/bento/build_dev_bentoml_whl.py
Expand Up @@ -7,6 +7,7 @@
from ..utils.pkg import source_locations
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ...grpc.utils import LATEST_PROTOCOL_VERSION
from ..configuration import is_pypi_installed_bentoml

logger = logging.getLogger(__name__)
Expand All @@ -15,7 +16,7 @@


def build_bentoml_editable_wheel(
target_path: str, *, _internal_stubs_version: str = "v1"
target_path: str, *, _internal_protocol_version: str = LATEST_PROTOCOL_VERSION
) -> None:
"""
This is for BentoML developers to create Bentos that contains the local bentoml
Expand Down Expand Up @@ -52,10 +53,10 @@ def build_bentoml_editable_wheel(
bentoml_path = Path(module_location)

if not Path(
module_location, "grpc", _internal_stubs_version, "service_pb2.py"
module_location, "grpc", _internal_protocol_version, "service_pb2.py"
).exists():
raise ModuleNotFoundError(
f"Generated stubs for version {_internal_stubs_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_stubs_version}' beforehand to generate gRPC stubs."
f"Generated stubs for version {_internal_protocol_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_protocol_version}' beforehand to generate gRPC stubs."
) from None

# location to pyproject.toml
Expand Down
2 changes: 1 addition & 1 deletion src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -36,7 +36,7 @@
IOType = t.TypeVar("IOType")


def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]:
    """Re-create an IODescriptor instance from its serialized spec mapping.

    The spec must carry an ``"id"`` key naming a registered descriptor class;
    construction is delegated to that class's own ``from_spec``.
    """
    if "id" not in spec:
        raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.")
    descriptor_cls = IO_DESCRIPTOR_REGISTRY[spec["id"]]
    return descriptor_cls.from_spec(spec)
Expand Down
41 changes: 26 additions & 15 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -30,6 +30,7 @@

import pydantic
import pydantic.schema as schema
from google.protobuf import message as _message
from google.protobuf import struct_pb2
from typing_extensions import Self

Expand Down Expand Up @@ -392,19 +393,29 @@ async def to_proto(self, obj: JSONType) -> struct_pb2.Value:
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj):
obj = obj.dict()
msg = struct_pb2.Value()
# To handle None cases.
if obj is not None:
from google.protobuf.json_format import ParseDict

if isinstance(obj, (dict, str, list, float, int, bool)):
# ParseDict handles google.protobuf.Struct type
# directly if given object has a supported type
ParseDict(obj, msg)
else:
# If given object doesn't have a supported type, we will
# use given JSON encoder to convert it to dictionary
# and then parse it to google.protobuf.Struct.
# Note that if a custom JSON encoder is used, it mustn't
# take any arguments.
ParseDict(self._json_encoder().default(obj), msg)
return parse_dict_to_proto(obj, msg, json_encoder=self._json_encoder)


def parse_dict_to_proto(
    obj: JSONType,
    msg: _message.Message,
    json_encoder: type[json.JSONEncoder] = DefaultJsonEncoder,
) -> t.Any:
    """Populate *msg* from *obj* using protobuf's JSON parsing and return *msg*.

    ``None`` is treated as "nothing to merge": *msg* is returned untouched.
    Objects that are not directly JSON-representable are first converted via
    the given encoder class's ``default`` hook.
    """
    if obj is None:
        # this function is an identity op for the msg if obj is None.
        return msg

    from google.protobuf.json_format import ParseDict

    # ParseDict handles google.protobuf.Struct type directly when the object
    # already has a JSON-compatible type; otherwise we run it through the
    # configured JSON encoder first. Note that a custom encoder class must be
    # constructible with no arguments.
    if not isinstance(obj, (dict, str, list, float, int, bool)):
        obj = json_encoder().default(obj)
    ParseDict(obj, msg)
    return msg
3 changes: 1 addition & 2 deletions src/bentoml/_internal/server/grpc/__init__.py
@@ -1,4 +1,3 @@
from .server import Server
from .servicer import Servicer

__all__ = ["Server", "Servicer"]
__all__ = ["Server"]
22 changes: 14 additions & 8 deletions src/bentoml/_internal/server/grpc/server.py
Expand Up @@ -4,6 +4,7 @@
import sys
import typing as t
import asyncio
import inspect
import logging
from typing import TYPE_CHECKING
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -29,7 +30,8 @@

from bentoml.grpc.v1 import service_pb2_grpc as services

from .servicer import Servicer
from ..grpc_app import GrpcServicerFactory

else:
grpc, aio = import_grpc()
_, services = import_generated_stubs()
Expand Down Expand Up @@ -61,7 +63,7 @@ class Server(aio._server.Server):
@inject
def __init__(
self,
servicer: Servicer,
servicer: GrpcServicerFactory,
bind_address: str,
max_message_length: int
| None = Provide[BentoMLContainer.grpc.max_message_length],
Expand All @@ -88,10 +90,6 @@ def __init__(
self.ssl_keyfile = ssl_keyfile
self.ssl_ca_certs = ssl_ca_certs

if not bool(self.servicer):
self.servicer.load()
assert self.servicer.loaded

super().__init__(
# Note that the max_workers are used inside ThreadPoolExecutor.
# This ThreadPoolExecutor are used by aio.Server() to execute non-AsyncIO RPC handlers.
Expand Down Expand Up @@ -189,7 +187,11 @@ async def startup(self) -> None:
from bentoml.exceptions import MissingDependencyException

# Running on_startup callback.
await self.servicer.startup()
for handler in self.servicer.on_startup:
out = handler()
if inspect.isawaitable(out):
await out

# register bento servicer
services.add_BentoServiceServicer_to_server(self.servicer.bento_servicer, self)
services_health.add_HealthServicer_to_server(
Expand Down Expand Up @@ -236,7 +238,11 @@ async def startup(self) -> None:

    async def shutdown(self) -> None:
        """Run registered on_shutdown callbacks, then gracefully stop the server.

        Callbacks may be plain callables or coroutine functions; awaitable
        results are awaited. Afterwards the gRPC server is stopped with the
        configured grace period, the health servicer enters graceful shutdown
        (so health checks start reporting NOT_SERVING), and the event loop is
        stopped.
        """
        # Running on_shutdown callbacks (sync or async).
        for handler in self.servicer.on_shutdown:
            out = handler()
            if inspect.isawaitable(out):
                await out

        await self.stop(grace=self.graceful_shutdown_timeout)
        await self.servicer.health_servicer.enter_graceful_shutdown()
        self.loop.stop()
Empty file.
101 changes: 101 additions & 0 deletions src/bentoml/_internal/server/grpc/servicer/v1/__init__.py
@@ -0,0 +1,101 @@
from __future__ import annotations

import sys
import asyncio
import logging
from typing import TYPE_CHECKING

import anyio

from ......exceptions import InvalidArgument
from ......exceptions import BentoMLException
from ......grpc.utils import import_grpc
from ......grpc.utils import grpc_status_code
from ......grpc.utils import validate_proto_fields
from ......grpc.utils import import_generated_stubs

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning)

import grpc

from bentoml.grpc.types import BentoServicerContext

from ......grpc.v1 import service_pb2 as pb
from ......grpc.v1 import service_pb2_grpc as services
from .....service.service import Service
else:
grpc, _ = import_grpc()
pb, services = import_generated_stubs(version="v1")


def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None:
    """Log the currently-handled exception against the API named in *request*."""
    # gRPC will always send a POST request.
    logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info)


def create_bento_servicer(service: Service) -> services.BentoServiceServicer:
    """
    This is the actual implementation of BentoServicer.
    Main inference entrypoint will be invoked via /bentoml.grpc.<version>.BentoService/Call
    """

    class BentoServiceImpl(services.BentoServiceServicer):
        """An asyncio implementation of BentoService servicer."""

        async def Call(  # type: ignore (no async types) # pylint: disable=invalid-overridden-method
            self: services.BentoServiceServicer,
            request: pb.Request,
            context: BentoServicerContext,
        ) -> pb.Response | None:
            """Handle one unary inference request.

            Resolves the API from ``request.api_name``, decodes the oneof
            payload via the API's input descriptor, invokes the user function
            (async functions directly, sync functions in a worker thread),
            then encodes the result into a ``pb.Response``. Failures are
            logged and converted to a gRPC abort with an appropriate status.
            """
            from functools import partial

            if request.api_name not in service.apis:
                raise InvalidArgument(
                    f"given 'api_name' is not defined in {service.name}",
                ) from None

            api = service.apis[request.api_name]
            response = pb.Response()

            # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved.
            # This is important so that we know the order of fields to process.
            # We will use fields descriptor to determine how to process that request.
            try:
                # we will check if the given fields list contains a pb.Multipart.
                input_proto = getattr(
                    request,
                    validate_proto_fields(request.WhichOneof("content"), api.input),
                )
                input_data = await api.input.from_proto(input_proto)
                if asyncio.iscoroutinefunction(api.func):
                    if api.multi_input:
                        output = await api.func(**input_data)
                    else:
                        output = await api.func(input_data)
                else:
                    if api.multi_input:
                        # BUGFIX: anyio.to_thread.run_sync forwards only
                        # positional arguments to the target callable, so the
                        # keywords must be bound with functools.partial first;
                        # passing **input_data directly makes run_sync itself
                        # raise TypeError on the unexpected keyword arguments.
                        output = await anyio.to_thread.run_sync(
                            partial(api.func, **input_data)
                        )
                    else:
                        output = await anyio.to_thread.run_sync(api.func, input_data)
                res = await api.output.to_proto(output)
                # TODO(aarnphm): support multiple proto fields
                response = pb.Response(**{api.output._proto_fields[0]: res})
            except BentoMLException as e:
                # Known BentoML errors carry their own gRPC status mapping.
                log_exception(request, sys.exc_info())
                await context.abort(code=grpc_status_code(e), details=e.message)
            except (RuntimeError, TypeError, NotImplementedError):
                log_exception(request, sys.exc_info())
                await context.abort(
                    code=grpc.StatusCode.INTERNAL,
                    details="A runtime error has occurred, see stacktrace from logs.",
                )
            except Exception:  # pylint: disable=broad-except
                # Catch-all so user-code failures never take down the server loop.
                log_exception(request, sys.exc_info())
                await context.abort(
                    code=grpc.StatusCode.INTERNAL,
                    details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.",
                )
            return response

    return BentoServiceImpl()

0 comments on commit 7ddf489

Please sign in to comment.