Skip to content

Commit

Permalink
fix: make sure descriptor_id is parsed correctly from servicer
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 26, 2022
1 parent db485b9 commit 6ea81ff
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 25 deletions.
8 changes: 6 additions & 2 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -30,13 +30,17 @@
)
OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]]

class SpecDict(t.TypedDict):
id: str
args: dict[str, t.Any]


IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {}

IOType = t.TypeVar("IOType")


def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]:
def from_spec(spec: SpecDict) -> IODescriptor[t.Any]:
if "id" not in spec:
raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.")
return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec)
Expand Down Expand Up @@ -129,7 +133,7 @@ def to_spec(self) -> dict[str, t.Any]:

@classmethod
@abstractmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
raise NotImplementedError

@abstractmethod
Expand Down
3 changes: 2 additions & 1 deletion src/bentoml/_internal/io_descriptors/file.py
Expand Up @@ -30,6 +30,7 @@

from bentoml.grpc.v1 import service_pb2 as pb

from .base import SpecDict
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

Expand Down Expand Up @@ -143,7 +144,7 @@ def _from_sample(self, sample: FileType | str) -> FileType:
return sample

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in File spec: {spec}")
return cls(**spec["args"])
Expand Down
3 changes: 2 additions & 1 deletion src/bentoml/_internal/io_descriptors/image.py
Expand Up @@ -36,6 +36,7 @@
from bentoml.grpc.v1 import service_pb2 as pb

from .. import external_typing as ext
from .base import SpecDict
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

Expand Down Expand Up @@ -245,7 +246,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Image spec: {spec}")

Expand Down
3 changes: 2 additions & 1 deletion src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -35,6 +35,7 @@
from typing_extensions import Self

from .. import external_typing as ext
from .base import SpecDict
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

Expand Down Expand Up @@ -235,7 +236,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in JSON spec: {spec}")
if "has_pydantic_model" in spec["args"] and spec["args"]["has_pydantic_model"]:
Expand Down
3 changes: 2 additions & 1 deletion src/bentoml/_internal/io_descriptors/multipart.py
Expand Up @@ -25,6 +25,7 @@

from bentoml.grpc.v1 import service_pb2 as pb

from .base import SpecDict
from .base import OpenAPIResponse
from ..types import LazyType
from ..context import InferenceApiContext as Context
Expand Down Expand Up @@ -202,7 +203,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Multipart spec: {spec}")
return Multipart(
Expand Down
3 changes: 2 additions & 1 deletion src/bentoml/_internal/io_descriptors/numpy.py
Expand Up @@ -28,6 +28,7 @@
from bentoml.grpc.v1 import service_pb2 as pb

from .. import external_typing as ext
from .base import SpecDict
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context
else:
Expand Down Expand Up @@ -252,7 +253,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in NumpyNdarray spec: {spec}")
res = NumpyNdarray(**spec["args"])
Expand Down
5 changes: 3 additions & 2 deletions src/bentoml/_internal/io_descriptors/pandas.py
Expand Up @@ -35,6 +35,7 @@
from bentoml.grpc.v1 import service_pb2 as pb

from .. import external_typing as ext
from .base import SpecDict
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

Expand Down Expand Up @@ -453,7 +454,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in PandasDataFrame spec: {spec}")
res = PandasDataFrame(**spec["args"])
Expand Down Expand Up @@ -900,7 +901,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in PandasSeries spec: {spec}")
res = PandasSeries(**spec["args"])
Expand Down
3 changes: 2 additions & 1 deletion src/bentoml/_internal/io_descriptors/text.py
Expand Up @@ -19,6 +19,7 @@
from google.protobuf import wrappers_pb2
from typing_extensions import Self

from .base import SpecDict
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context
else:
Expand Down Expand Up @@ -111,7 +112,7 @@ def to_spec(self) -> dict[str, t.Any]:
return {"id": self.descriptor_id}

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
def from_spec(cls, spec: SpecDict) -> Self:
return cls()

def openapi_schema(self) -> Schema:
Expand Down
35 changes: 20 additions & 15 deletions src/bentoml/_internal/server/grpc/servicer.py
Expand Up @@ -35,14 +35,17 @@
from bentoml.grpc.types import BentoServicerContext

from ...service.service import Service
from ...io_descriptors.base import IODescriptor
from ...io_descriptors.base import SpecDict

if LATEST_PROTOCOL_VERSION == "v1":
from bentoml.grpc.v1 import service_pb2 as pb
from bentoml.grpc.v1 import service_pb2_grpc as services
from bentoml.grpc.v1.service_pb2 import ServiceMetadataRequest
from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse
else:
from bentoml.grpc.v1alpha1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2_grpc as services

else:
grpc, aio = import_grpc()
health = LazyLoader(
Expand Down Expand Up @@ -194,7 +197,6 @@ async def Call(

if protocol_version == "v1":
# "v1" introduces ServiceMetadata to send in bentoml.Service information.
from bentoml.grpc.v1.service_pb2 import ServiceMetadataRequest
from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse

async def ServiceMetadata(
Expand All @@ -209,13 +211,11 @@ async def ServiceMetadata(
ServiceMetadataResponse.InferenceAPI(
name=api.name,
docs=api.doc,
input=ServiceMetadataResponse.DescriptorMetadata(
descriptor_id=api.input.descriptor_id,
attributes=make_attributes_struct(api.input),
input=make_descriptor_spec(
api.input.to_spec(), ServiceMetadataResponse
),
output=ServiceMetadataResponse.DescriptorMetadata(
descriptor_id=api.output.descriptor_id,
attributes=make_attributes_struct(api.output),
output=make_descriptor_spec(
api.output.to_spec(), ServiceMetadataResponse
),
)
for api in service.apis.values()
Expand Down Expand Up @@ -253,13 +253,18 @@ def _tuple_converter(d: NestedDictStrAny | None) -> NestedDictStrAny | None:
return d


def make_attributes_struct(io: IODescriptor[t.Any]) -> struct_pb2.Struct:
def make_descriptor_spec(
spec: SpecDict, pb: type[ServiceMetadataResponse]
) -> ServiceMetadataResponse.DescriptorMetadata:
from ...io_descriptors.json import parse_dict_to_proto

return struct_pb2.Struct(
fields={
"args": parse_dict_to_proto(
_tuple_converter(io.to_spec().get("args", None)), struct_pb2.Value()
)
}
return pb.DescriptorMetadata(
descriptor_id=spec["id"],
attributes=struct_pb2.Struct(
fields={
"args": parse_dict_to_proto(
_tuple_converter(spec["args"]), struct_pb2.Value()
)
}
),
)

0 comments on commit 6ea81ff

Please sign in to comment.