Skip to content

Commit

Permalink
feat(grpc): adding service metadata
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 25, 2022
1 parent 4e771b6 commit db485b9
Show file tree
Hide file tree
Showing 10 changed files with 633 additions and 112 deletions.
42 changes: 26 additions & 16 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -30,6 +30,7 @@

import pydantic
import pydantic.schema as schema
from google.protobuf import message as _message
from google.protobuf import struct_pb2
from typing_extensions import Self

Expand Down Expand Up @@ -392,19 +393,28 @@ async def to_proto(self, obj: JSONType) -> struct_pb2.Value:
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj):
obj = obj.dict()
msg = struct_pb2.Value()
# To handle None cases.
if obj is not None:
from google.protobuf.json_format import ParseDict

if isinstance(obj, (dict, str, list, float, int, bool)):
# ParseDict handles google.protobuf.Struct type
# directly if given object has a supported type
ParseDict(obj, msg)
else:
# If given object doesn't have a supported type, we will
# use given JSON encoder to convert it to dictionary
# and then parse it to google.protobuf.Struct.
# Note that if a custom JSON encoder is used, it mustn't
# take any arguments.
ParseDict(self._json_encoder().default(obj), msg)
return msg
return parse_dict_to_proto(obj, msg, json_encoder=self._json_encoder)


def parse_dict_to_proto(
    obj: JSONType,
    msg: _message.Message,
    json_encoder: type[json.JSONEncoder] = DefaultJsonEncoder,
) -> t.Any:
    """Fill ``msg`` from a JSON-like Python object and return it.

    Args:
        obj: The object to parse into ``msg``. ``None`` leaves ``msg`` untouched.
        msg: The protobuf message that ``ParseDict`` populates in place.
        json_encoder: Encoder class used as a fallback for objects that are not
            one of ``dict``, ``str``, ``list``, ``float``, ``int`` or ``bool``.
            It is instantiated with no arguments, so a custom encoder mustn't
            require any constructor arguments.

    Returns:
        The same ``msg`` object that was passed in.
    """
    # None case: this function is an identity op for the msg.
    if obj is None:
        return msg

    from google.protobuf.json_format import ParseDict

    natively_supported = (dict, str, list, float, int, bool)
    if isinstance(obj, natively_supported):
        # ParseDict can consume these types directly.
        payload = obj
    else:
        # Unsupported type: let the JSON encoder's ``default`` hook turn the
        # object into something ParseDict understands.
        payload = json_encoder().default(obj)

    ParseDict(payload, msg)
    return msg
231 changes: 160 additions & 71 deletions src/bentoml/_internal/server/grpc/servicer.py
Expand Up @@ -9,14 +9,14 @@

import anyio

from bentoml.grpc.utils import import_grpc
from bentoml.grpc.utils import grpc_status_code
from bentoml.grpc.utils import validate_proto_fields
from bentoml.grpc.utils import import_generated_stubs

from ...utils import LazyLoader
from ....exceptions import InvalidArgument
from ....exceptions import BentoMLException
from ....grpc.utils import import_grpc
from ....grpc.utils import grpc_status_code
from ....grpc.utils import validate_proto_fields
from ....grpc.utils import import_generated_stubs
from ....grpc.utils import LATEST_PROTOCOL_VERSION

logger = logging.getLogger(__name__)

Expand All @@ -26,29 +26,32 @@
import grpc
from grpc import aio
from grpc_health.v1 import health
from google.protobuf import struct_pb2
from typing_extensions import Self

from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1 import service_pb2_grpc as services
from bentoml.grpc.types import Interceptors
from bentoml.grpc.types import AddServicerFn
from bentoml.grpc.types import ServicerClass
from bentoml.grpc.types import BentoServicerContext

from ...service.service import Service

from ...io_descriptors.base import IODescriptor

if LATEST_PROTOCOL_VERSION == "v1":
from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1 import service_pb2_grpc as services
else:
from bentoml.grpc.v1alpha1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2_grpc as services
else:
pb, services = import_generated_stubs()
grpc, aio = import_grpc()
health = LazyLoader(
"health",
globals(),
"grpc_health.v1.health",
exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.",
)
containers = LazyLoader(
"containers", globals(), "google.protobuf.internal.containers"
)
struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2")


def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None:
Expand All @@ -67,6 +70,7 @@ def __init__(
mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]]
| None = None,
interceptors: Interceptors | None = None,
protocol_version: str = LATEST_PROTOCOL_VERSION,
) -> None:
self.bento_service = service

Expand All @@ -75,13 +79,17 @@ def __init__(
self.mount_servicers = [] if not mount_servicers else list(mount_servicers)
self.interceptors = [] if not interceptors else list(interceptors)
self.loaded = False
self.protocol_version = protocol_version

def load(self):
pb, _ = import_generated_stubs(self.protocol_version)
assert not self.loaded

self.interceptors_stack = self.build_interceptors_stack()

self.bento_servicer = create_bento_servicer(self.bento_service)
self.bento_servicer = create_bento_servicer(
self.bento_service, protocol_version=self.protocol_version
)

# Create a health check servicer. We use the non-blocking implementation
# to avoid thread starvation.
Expand Down Expand Up @@ -111,66 +119,147 @@ def __bool__(self):
return self.loaded


def create_bento_servicer(service: Service) -> services.BentoServiceServicer:
def create_bento_servicer(
service: Service, protocol_version: str = LATEST_PROTOCOL_VERSION
) -> services.BentoServiceServicer:
"""
This is the actual implementation of BentoServicer.
Main inference entrypoint will be invoked via /bentoml.grpc.<version>.BentoService/Call
"""

class BentoServiceImpl(services.BentoServiceServicer):
"""An asyncio implementation of BentoService servicer."""

async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method
self,
request: pb.Request,
context: BentoServicerContext,
) -> pb.Response | None:
if request.api_name not in service.apis:
raise InvalidArgument(
f"given 'api_name' is not defined in {service.name}",
) from None

api = service.apis[request.api_name]
response = pb.Response()

# NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved.
# This is important so that we know the order of fields to process.
# We will use fields descriptor to determine how to process that request.
try:
# we will check if the given fields list contains a pb.Multipart.
input_proto = getattr(
request,
validate_proto_fields(request.WhichOneof("content"), api.input),
)
input_data = await api.input.from_proto(input_proto)
if asyncio.iscoroutinefunction(api.func):
if api.multi_input:
output = await api.func(**input_data)
else:
output = await api.func(input_data)
if protocol_version == "v1":
from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1.service_pb2_grpc import BentoServiceServicer
else:
from bentoml.grpc.v1alpha1 import service_pb2 as pb
from bentoml.grpc.v1alpha1.service_pb2_grpc import BentoServiceServicer

attrs: dict[str, t.Any] = {
"__doc__": "An asyncio implementation of BentoService servicer."
}

async def Call(
_: BentoServiceServicer,
request: pb.Request,
context: BentoServicerContext,
) -> pb.Response | None:
if request.api_name not in service.apis:
raise InvalidArgument(
f"given 'api_name' is not defined in {service.name}",
) from None

api = service.apis[request.api_name]
response = pb.Response()

# NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved.
# This is important so that we know the order of fields to process.
# We will use fields descriptor to determine how to process that request.
try:
# we will check if the given fields list contains a pb.Multipart.
input_proto = getattr(
request,
validate_proto_fields(request.WhichOneof("content"), api.input),
)
input_data = await api.input.from_proto(input_proto)
if asyncio.iscoroutinefunction(api.func):
if api.multi_input:
output = await api.func(**input_data)
else:
output = await api.func(input_data)
else:
if api.multi_input:
output = await anyio.to_thread.run_sync(api.func, **input_data)
else:
if api.multi_input:
output = await anyio.to_thread.run_sync(api.func, **input_data)
else:
output = await anyio.to_thread.run_sync(api.func, input_data)
res = await api.output.to_proto(output)
# TODO(aarnphm): support multiple proto fields
response = pb.Response(**{api.output._proto_fields[0]: res})
except BentoMLException as e:
log_exception(request, sys.exc_info())
await context.abort(code=grpc_status_code(e), details=e.message)
except (RuntimeError, TypeError, NotImplementedError):
log_exception(request, sys.exc_info())
await context.abort(
code=grpc.StatusCode.INTERNAL,
details="A runtime error has occurred, see stacktrace from logs.",
)
except Exception: # pylint: disable=broad-except
log_exception(request, sys.exc_info())
await context.abort(
code=grpc.StatusCode.INTERNAL,
details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.",
)
return response

return BentoServiceImpl()
output = await anyio.to_thread.run_sync(api.func, input_data)
res = await api.output.to_proto(output)
# TODO(aarnphm): support multiple proto fields
response = pb.Response(**{api.output._proto_fields[0]: res})
except BentoMLException as e:
log_exception(request, sys.exc_info())
await context.abort(code=grpc_status_code(e), details=e.message)
except (RuntimeError, TypeError, NotImplementedError):
log_exception(request, sys.exc_info())
await context.abort(
code=grpc.StatusCode.INTERNAL,
details="A runtime error has occurred, see stacktrace from logs.",
)
except Exception: # pylint: disable=broad-except
log_exception(request, sys.exc_info())
await context.abort(
code=grpc.StatusCode.INTERNAL,
details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.",
)
return response

attrs.setdefault("Call", Call)

if protocol_version == "v1":
# "v1" introduces ServiceMetadata to send in bentoml.Service information.
from bentoml.grpc.v1.service_pb2 import ServiceMetadataRequest
from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse

async def ServiceMetadata(
_: BentoServiceServicer,
request: ServiceMetadataRequest, # pylint: disable=unused-argument
context: BentoServicerContext, # pylint: disable=unused-argument
) -> ServiceMetadataResponse:
return ServiceMetadataResponse(
name=service.name,
docs=service.doc,
apis=[
ServiceMetadataResponse.InferenceAPI(
name=api.name,
docs=api.doc,
input=ServiceMetadataResponse.DescriptorMetadata(
descriptor_id=api.input.descriptor_id,
attributes=make_attributes_struct(api.input),
),
output=ServiceMetadataResponse.DescriptorMetadata(
descriptor_id=api.output.descriptor_id,
attributes=make_attributes_struct(api.output),
),
)
for api in service.apis.values()
],
)

attrs.setdefault("ServiceMetadata", ServiceMetadata)

if TYPE_CHECKING:
# NOTE: typeshed only accept type expression for type() class creation.
# Hence, pyright will raise an error if we only pass in BentoServiceServicer, as it won't
# acknowledge BentoServiceServicer as a type expression.
BentoServiceServicerT = type(BentoServiceServicer)
else:
BentoServiceServicerT = BentoServiceServicer

return type("BentoServiceImpl", (BentoServiceServicerT,), attrs)()


if TYPE_CHECKING:
NestedDictStrAny = dict[str, dict[str, t.Any] | t.Any]
TupleAny = tuple[t.Any, ...]


def _tuple_converter(d: NestedDictStrAny | None) -> NestedDictStrAny | None:
# handles case for struct_pb2.Value where nested items are tuple.
# if that is the case, then convert to list.
# This dict is only one level deep, as we don't allow nested Multipart.
if d is not None:
for key, value in d.items():
if isinstance(value, tuple):
d[key] = list(t.cast("TupleAny", value))
elif isinstance(value, dict):
d[key] = _tuple_converter(t.cast("NestedDictStrAny", value))
return d


def make_attributes_struct(io: IODescriptor[t.Any]) -> struct_pb2.Struct:
    """Build the ``attributes`` struct describing an IO descriptor.

    The descriptor's spec is read via ``io.to_spec()``; its ``"args"`` entry
    (or ``None`` when absent) has tuples converted to lists and is then parsed
    into a ``struct_pb2.Value`` stored under the ``"args"`` field.

    Args:
        io: The IO descriptor whose spec arguments should be exported.

    Returns:
        A ``struct_pb2.Struct`` with a single ``"args"`` field.
    """
    from ...io_descriptors.json import parse_dict_to_proto

    spec_args = _tuple_converter(io.to_spec().get("args", None))
    args_value = parse_dict_to_proto(spec_args, struct_pb2.Value())
    return struct_pb2.Struct(fields={"args": args_value})
2 changes: 1 addition & 1 deletion src/bentoml/_internal/service/inference_api.py
Expand Up @@ -26,7 +26,7 @@
class InferenceAPI:
def __init__(
self,
user_defined_callback: t.Callable[..., t.Any] | None,
user_defined_callback: t.Callable[..., t.Any],
input_descriptor: IODescriptor[t.Any],
output_descriptor: IODescriptor[t.Any],
name: Optional[str],
Expand Down
2 changes: 2 additions & 0 deletions src/bentoml/grpc/utils/__init__.py
Expand Up @@ -10,6 +10,7 @@
from bentoml.exceptions import InvalidArgument
from bentoml.grpc.utils._import_hook import import_grpc
from bentoml.grpc.utils._import_hook import import_generated_stubs
from bentoml.grpc.utils._import_hook import LATEST_PROTOCOL_VERSION

if TYPE_CHECKING:
from enum import Enum
Expand All @@ -36,6 +37,7 @@
"import_generated_stubs",
"import_grpc",
"validate_proto_fields",
"LATEST_PROTOCOL_VERSION",
]

logger = logging.getLogger(__name__)
Expand Down
4 changes: 3 additions & 1 deletion src/bentoml/grpc/utils/_import_hook.py
Expand Up @@ -5,9 +5,11 @@
if TYPE_CHECKING:
import types

LATEST_PROTOCOL_VERSION = "v1"


def import_generated_stubs(
version: str = "v1",
version: str = LATEST_PROTOCOL_VERSION,
file: str = "service.proto",
) -> tuple[types.ModuleType, types.ModuleType]:
"""
Expand Down

0 comments on commit db485b9

Please sign in to comment.