From 591418b5a00ba9c89c4eb139bfa91ca0fa690b8d Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Tue, 8 Nov 2022 03:11:10 -0800 Subject: [PATCH] chore: refactor and __slots__ implementation for Base Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- src/bentoml/_internal/io_descriptors/base.py | 86 +++++++++++++------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/src/bentoml/_internal/io_descriptors/base.py b/src/bentoml/_internal/io_descriptors/base.py index 820734ae026..83655c7900a 100644 --- a/src/bentoml/_internal/io_descriptors/base.py +++ b/src/bentoml/_internal/io_descriptors/base.py @@ -50,13 +50,44 @@ def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]: return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec) -class IODescriptor(ABC, t.Generic[IOType]): +class _OpenAPIMeta: + @abstractmethod + def openapi_schema(self) -> Schema | Reference: + raise NotImplementedError + + @abstractmethod + def openapi_components(self) -> dict[str, t.Any] | None: + raise NotImplementedError + + @abstractmethod + def openapi_example(self) -> t.Any | None: + raise NotImplementedError + + @abstractmethod + def openapi_request_body(self) -> dict[str, t.Any]: + raise NotImplementedError + + @abstractmethod + def openapi_responses(self) -> dict[str, t.Any]: + raise NotImplementedError + + +class IODescriptor(ABC, _OpenAPIMeta, t.Generic[IOType]): """ IODescriptor describes the input/output data format of an InferenceAPI defined in a :code:`bentoml.Service`. This is an abstract base class for extending new HTTP endpoint IO descriptor types in BentoServer. """ + __slots__ = ( + "_initialized", + "_args", + "_kwargs", + "_proto_fields", + "_mime_type", + "descriptor_id", + ) + HTTP_METHODS = ["POST"] descriptor_id: str | None @@ -65,7 +96,9 @@ class IODescriptor(ABC, t.Generic[IOType]): _rpc_content_type: str = "application/grpc" _proto_fields: tuple[ProtoField] _sample: IOType | None = None - _set_sample: singledispatchmethod["IODescriptor[t.Any]"] = set_sample + _initialized: bool + _args: t.Sequence[t.Any] + _kwargs: dict[str, t.Any] def __init_subclass__(cls, *, descriptor_id: str | None = None): if descriptor_id is not None: @@ -75,17 +108,35 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None): ) IO_DESCRIPTOR_REGISTRY[descriptor_id] = cls cls.descriptor_id = descriptor_id + cls._initialized = False def __new__(cls, *args: t.Any, **kwargs: t.Any) -> Self: sample = kwargs.pop("_sample", None) - kls = object.__new__(cls) + klass = object.__new__(cls) if sample is None: set_sample.register(type(None), lambda self, _: self) - kls = kls._set_sample(sample) - # TODO: lazy init - kls.__init__(*args, **kwargs) + kls = klass._set_sample(sample) + kls._args = args + kls._kwargs = kwargs return kls + def __getattr__(self, name: str) -> t.Any: + if not self._initialized: + self._lazy_init() + assert self._initialized + return getattr(self, name) + + def __repr__(self) -> str: + return self.__class__.__qualname__ + + def _lazy_init(self) -> None: + self.__init__(*self._args, **self._kwargs) + self._initialized = True + del self._args + del self._kwargs + + _set_sample: singledispatchmethod[IODescriptor[t.Any]] = set_sample + @property def sample(self) -> IOType | None: return self._sample @@ -114,33 +165,10 @@ def to_spec(self) -> dict[str, t.Any]: def from_spec(cls, spec: dict[str, t.Any]) -> Self: raise NotImplementedError - def __repr__(self) -> str: - return self.__class__.__qualname__ - @abstractmethod def input_type(self) -> InputType: raise NotImplementedError - @abstractmethod - def openapi_schema(self) -> Schema | Reference: - raise NotImplementedError - - @abstractmethod - def openapi_components(self) -> dict[str, t.Any] | None: - raise NotImplementedError - - @abstractmethod - def openapi_example(self) -> t.Any | None: - raise NotImplementedError - - @abstractmethod - def openapi_request_body(self) -> dict[str, t.Any]: - raise NotImplementedError - - @abstractmethod - def openapi_responses(self) -> dict[str, t.Any]: - raise NotImplementedError - @abstractmethod async def from_http_request(self, request: Request) -> IOType: ...