diff --git a/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py index d88e7613e7d..a12fa87c274 100644 --- a/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py +++ b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py @@ -1,12 +1,14 @@ from __future__ import annotations import sys +import typing as t import asyncio import logging from typing import TYPE_CHECKING import anyio +from .....utils import LazyLoader from ......exceptions import InvalidArgument from ......exceptions import BentoMLException from ......grpc.utils import import_grpc @@ -20,6 +22,7 @@ from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) import grpc + from google.protobuf import struct_pb2 from bentoml.grpc.types import BentoServicerContext @@ -29,6 +32,7 @@ else: grpc, _ = import_grpc() pb, services = import_generated_stubs(version="v1") + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: @@ -46,7 +50,7 @@ class BentoServiceImpl(services.BentoServiceServicer): """An asyncio implementation of BentoService servicer.""" async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method - self: services.BentoServiceServicer, + self, request: pb.Request, context: BentoServicerContext, ) -> pb.Response | None: @@ -98,4 +102,62 @@ async def Call( # type: ignore (no async types) # pylint: disable=invalid-overr ) return response + async def ServiceMetadata( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self: services.BentoServiceServicer, + request: pb.ServiceMetadataRequest, # pylint: disable=unused-argument + context: BentoServicerContext, # pylint: disable=unused-argument + ) -> pb.ServiceMetadataResponse: + return pb.ServiceMetadataResponse( + name=service.name, + docs=service.doc, + apis=[ + pb.ServiceMetadataResponse.InferenceAPI( + 
if TYPE_CHECKING:
    NestedDictStrAny = dict[str, dict[str, t.Any] | t.Any]
    TupleAny = tuple[t.Any, ...]


def _listify_tuples(value: t.Any) -> t.Any:
    """Return a copy of *value* in which every tuple, at any depth, is a list.

    Dicts and lists are rebuilt (never modified in place) so the caller's
    object is left untouched; any other value is returned unchanged.
    """
    if isinstance(value, dict):
        return {key: _listify_tuples(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_listify_tuples(item) for item in value]
    return value


def _tuple_converter(d: NestedDictStrAny | None) -> NestedDictStrAny | None:
    """Convert tuples nested inside a descriptor spec mapping into lists.

    This exists for ``struct_pb2.Value``, whose nested items must not be
    tuples. NOTE(review): presumably the protobuf dict parser only accepts
    ``list`` for repeated values -- confirm against ``parse_dict_to_proto``.

    Unlike the previous implementation this function:

    * does not mutate ``d``; a converted copy is returned instead, and
    * converts tuples at any nesting depth, including tuples found inside
      other tuples or inside lists, not only tuples that are direct dict
      values.

    Args:
        d: Mapping to convert, or ``None``.

    Returns:
        A new mapping with every tuple replaced by a list, or ``None`` when
        ``d`` is ``None``.
    """
    if d is None:
        return None
    return _listify_tuples(d)


def make_descriptor_spec(
    spec: dict[str, t.Any], pb: type[pb.ServiceMetadataResponse]
) -> pb.ServiceMetadataResponse.DescriptorMetadata:
    """Build the ``DescriptorMetadata`` message for one IO descriptor spec.

    Args:
        spec: Descriptor specification. It must contain an ``"id"`` entry;
            every other entry becomes one field of the ``attributes`` struct.
            The mapping is only read, never modified.
        pb: Class exposing ``DescriptorMetadata``; the servicer passes
            ``pb.ServiceMetadataResponse``. Note that this parameter shadows
            the module-level ``pb`` stubs inside this function.

    Returns:
        A ``DescriptorMetadata`` carrying the descriptor id and its
        attributes.

    Raises:
        KeyError: If ``spec`` has no ``"id"`` entry.
    """
    # Imported lazily, as in the original implementation.
    from .....io_descriptors.json import parse_dict_to_proto

    # Read the id instead of popping it so the caller's dict stays intact.
    descriptor_id = spec["id"]
    return pb.DescriptorMetadata(
        descriptor_id=descriptor_id,
        attributes=struct_pb2.Struct(
            fields={
                key: parse_dict_to_proto(_tuple_converter(value), struct_pb2.Value())
                for key, value in spec.items()
                if key != "id"
            }
        ),
    )