diff --git a/src/bentoml/_internal/frameworks/tensorflow_v2.py b/src/bentoml/_internal/frameworks/tensorflow_v2.py index 2bfab21845d..3f603c7fe5d 100644 --- a/src/bentoml/_internal/frameworks/tensorflow_v2.py +++ b/src/bentoml/_internal/frameworks/tensorflow_v2.py @@ -35,11 +35,12 @@ from ..external_typing import tensorflow as tf_ext TFArgType = t.Union[t.List[t.Union[int, float]], ext.NpNDArray, tf_ext.Tensor] + TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor] + TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray] + try: import tensorflow as tf - TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor] - TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray] except ImportError: # pragma: no cover raise MissingDependencyException( "'tensorflow' is required in order to use module 'bentoml.tensorflow', install tensorflow with 'pip install tensorflow'. For more information, refer to https://www.tensorflow.org/install" @@ -70,7 +71,7 @@ def get(tag_like: str | Tag) -> bentoml.Model: def load_model( bento_model: str | Tag | bentoml.Model, device_name: str = "/device:CPU:0", -) -> "tf_ext.AutoTrackable" | "tf_ext.Module": +) -> tf_ext.AutoTrackable | tf_ext.Module: """ Load a tensorflow model from BentoML local modelstore with given name. @@ -343,7 +344,7 @@ def _run_method( def add_run_method(method_name: str, options: ModelSignature): def run_method( - runnable_self: TensorflowRunnable, *args: "TFArgType", **kwargs: "TFArgType" + runnable_self: TensorflowRunnable, *args: TFArgType, **kwargs: TFArgType ) -> TFRunnableOutputType: _run_method = runnable_self.methods_cache.get( method_name @@ -374,9 +375,9 @@ class TensorflowTensorContainer( ): @classmethod def batches_to_batch( - cls, batches: t.Sequence["tf_ext.EagerTensor"], batch_dim: int = 0 - ) -> t.Tuple["tf_ext.EagerTensor", list[int]]: - batch: "tf_ext.EagerTensor" = tf.concat(batches, axis=batch_dim) + cls, batches: t.Sequence[tf_ext.EagerTensor], batch_dim: int = 0 + ) -> t.Tuple[tf_ext.EagerTensor, list[int]]: + batch: tf_ext.EagerTensor = tf.concat(batches, axis=batch_dim) # TODO: fix typing mismatch @larme indices: list[int] = list( itertools.accumulate(subbatch.shape[batch_dim] for subbatch in batches) @@ -386,15 +387,15 @@ def batches_to_batch( @classmethod def batch_to_batches( - cls, batch: "tf_ext.EagerTensor", indices: t.Sequence[int], batch_dim: int = 0 - ) -> t.List["tf_ext.EagerTensor"]: + cls, batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0 + ) -> t.List[tf_ext.EagerTensor]: size_splits = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)] return tf.split(batch, size_splits, axis=batch_dim) # type: ignore @classmethod def to_payload( cls, - batch: "tf_ext.EagerTensor", + batch: tf_ext.EagerTensor, batch_dim: int = 0, ) -> Payload: @@ -407,14 +408,14 @@ def to_payload( def from_payload( cls, payload: Payload, - ) -> "tf_ext.EagerTensor": + ) -> tf_ext.EagerTensor: return pickle.loads(payload.data) @classmethod def batch_to_payloads( cls, - batch: "tf_ext.EagerTensor", + batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0, ) -> t.List[Payload]: @@ -429,7 +430,7 @@ def from_batch_payloads( cls, payloads: t.Sequence[Payload], batch_dim: int = 0, - ) -> t.Tuple["tf_ext.EagerTensor", t.List[int]]: + ) -> t.Tuple[tf_ext.EagerTensor, t.List[int]]: batches = [cls.from_payload(payload) for payload in payloads] return cls.batches_to_batch(batches, batch_dim)