Skip to content

Commit

Permalink
tidy up type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
larme committed Nov 4, 2022
1 parent e636774 commit 23e8ecf
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/bentoml/_internal/frameworks/tensorflow_v2.py
Expand Up @@ -35,11 +35,12 @@
from ..external_typing import tensorflow as tf_ext

TFArgType = t.Union[t.List[t.Union[int, float]], ext.NpNDArray, tf_ext.Tensor]
TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor]
TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray]


try:
import tensorflow as tf
TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor]
TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray]
except ImportError: # pragma: no cover
raise MissingDependencyException(
"'tensorflow' is required in order to use module 'bentoml.tensorflow', install tensorflow with 'pip install tensorflow'. For more information, refer to https://www.tensorflow.org/install"
Expand Down Expand Up @@ -70,7 +71,7 @@ def get(tag_like: str | Tag) -> bentoml.Model:
def load_model(
bento_model: str | Tag | bentoml.Model,
device_name: str = "/device:CPU:0",
) -> "tf_ext.AutoTrackable" | "tf_ext.Module":
) -> tf_ext.AutoTrackable | tf_ext.Module:
"""
Load a tensorflow model from BentoML local modelstore with given name.
Expand Down Expand Up @@ -343,7 +344,7 @@ def _run_method(

def add_run_method(method_name: str, options: ModelSignature):
def run_method(
runnable_self: TensorflowRunnable, *args: "TFArgType", **kwargs: "TFArgType"
runnable_self: TensorflowRunnable, *args: TFArgType, **kwargs: TFArgType
) -> TFRunnableOutputType:
_run_method = runnable_self.methods_cache.get(
method_name
Expand Down Expand Up @@ -374,9 +375,9 @@ class TensorflowTensorContainer(
):
@classmethod
def batches_to_batch(
cls, batches: t.Sequence["tf_ext.EagerTensor"], batch_dim: int = 0
) -> t.Tuple["tf_ext.EagerTensor", list[int]]:
batch: "tf_ext.EagerTensor" = tf.concat(batches, axis=batch_dim)
cls, batches: t.Sequence[tf_ext.EagerTensor], batch_dim: int = 0
) -> t.Tuple[tf_ext.EagerTensor, list[int]]:
batch: tf_ext.EagerTensor = tf.concat(batches, axis=batch_dim)
# TODO: fix typing mismatch @larme
indices: list[int] = list(
itertools.accumulate(subbatch.shape[batch_dim] for subbatch in batches)
Expand All @@ -386,15 +387,15 @@ def batches_to_batch(

@classmethod
def batch_to_batches(
cls, batch: "tf_ext.EagerTensor", indices: t.Sequence[int], batch_dim: int = 0
) -> t.List["tf_ext.EagerTensor"]:
cls, batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0
) -> t.List[tf_ext.EagerTensor]:
size_splits = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)]
return tf.split(batch, size_splits, axis=batch_dim) # type: ignore

@classmethod
def to_payload(
cls,
batch: "tf_ext.EagerTensor",
batch: tf_ext.EagerTensor,
batch_dim: int = 0,
) -> Payload:

Expand All @@ -407,14 +408,14 @@ def to_payload(
def from_payload(
cls,
payload: Payload,
) -> "tf_ext.EagerTensor":
) -> tf_ext.EagerTensor:

return pickle.loads(payload.data)

@classmethod
def batch_to_payloads(
cls,
batch: "tf_ext.EagerTensor",
batch: tf_ext.EagerTensor,
indices: t.Sequence[int],
batch_dim: int = 0,
) -> t.List[Payload]:
Expand All @@ -429,7 +430,7 @@ def from_batch_payloads(
cls,
payloads: t.Sequence[Payload],
batch_dim: int = 0,
) -> t.Tuple["tf_ext.EagerTensor", t.List[int]]:
) -> t.Tuple[tf_ext.EagerTensor, t.List[int]]:
batches = [cls.from_payload(payload) for payload in payloads]
return cls.batches_to_batch(batches, batch_dim)

Expand Down

0 comments on commit 23e8ecf

Please sign in to comment.