Skip to content

Commit

Permalink
chore: refactor and __slots__ implementation for Base
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 8, 2022
1 parent dafcac5 commit 591418b
Showing 1 changed file with 57 additions and 29 deletions.
86 changes: 57 additions & 29 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -50,13 +50,44 @@ def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]:
return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec)


class IODescriptor(ABC, t.Generic[IOType]):
class _OpenAPIMeta:
@abstractmethod
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError

@abstractmethod
def openapi_example(self) -> t.Any | None:
raise NotImplementedError

@abstractmethod
def openapi_request_body(self) -> dict[str, t.Any]:
raise NotImplementedError

@abstractmethod
def openapi_responses(self) -> dict[str, t.Any]:
raise NotImplementedError


class IODescriptor(ABC, _OpenAPIMeta, t.Generic[IOType]):
"""
IODescriptor describes the input/output data format of an InferenceAPI defined
in a :code:`bentoml.Service`. This is an abstract base class for extending new HTTP
endpoint IO descriptor types in BentoServer.
"""

__slots__ = (
"_initialized",
"_args",
"_kwargs",
"_proto_fields",
"_mime_type",
"descriptor_id",
)

HTTP_METHODS = ["POST"]

descriptor_id: str | None
Expand All @@ -65,7 +96,9 @@ class IODescriptor(ABC, t.Generic[IOType]):
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
_sample: IOType | None = None
_set_sample: singledispatchmethod["IODescriptor[t.Any]"] = set_sample
_initialized: bool
_args: t.Sequence[t.Any]
_kwargs: dict[str, t.Any]

def __init_subclass__(cls, *, descriptor_id: str | None = None):
if descriptor_id is not None:
Expand All @@ -75,17 +108,35 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
)
IO_DESCRIPTOR_REGISTRY[descriptor_id] = cls
cls.descriptor_id = descriptor_id
cls._initialized = False

def __new__(cls, *args: t.Any, **kwargs: t.Any) -> Self:
sample = kwargs.pop("_sample", None)
kls = object.__new__(cls)
klass = object.__new__(cls)
if sample is None:
set_sample.register(type(None), lambda self, _: self)
kls = kls._set_sample(sample)
# TODO: lazy init
kls.__init__(*args, **kwargs)
kls = klass._set_sample(sample)
kls._args = args
kls._kwargs = kwargs
return kls

def __getattr__(self, name: str) -> t.Any:
if not self._initialized:
self._lazy_init()
assert self._initialized
return getattr(self, name)

def __repr__(self) -> str:
return self.__class__.__qualname__

def _lazy_init(self) -> None:
self.__init__(*self._args, **self._kwargs)
self._initialized = True
del self._args
del self._kwargs

_set_sample: singledispatchmethod[IODescriptor[t.Any]] = set_sample

@property
def sample(self) -> IOType | None:
return self._sample
Expand Down Expand Up @@ -114,33 +165,10 @@ def to_spec(self) -> dict[str, t.Any]:
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError

def __repr__(self) -> str:
return self.__class__.__qualname__

@abstractmethod
def input_type(self) -> InputType:
raise NotImplementedError

@abstractmethod
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError

@abstractmethod
def openapi_example(self) -> t.Any | None:
raise NotImplementedError

@abstractmethod
def openapi_request_body(self) -> dict[str, t.Any]:
raise NotImplementedError

@abstractmethod
def openapi_responses(self) -> dict[str, t.Any]:
raise NotImplementedError

@abstractmethod
async def from_http_request(self, request: Request) -> IOType:
...
Expand Down

0 comments on commit 591418b

Please sign in to comment.