diff --git a/src/bentoml/_internal/external_typing/tensorflow.py b/src/bentoml/_internal/external_typing/tensorflow.py index 7b6fd43f343..bbfce38f6e1 100644 --- a/src/bentoml/_internal/external_typing/tensorflow.py +++ b/src/bentoml/_internal/external_typing/tensorflow.py @@ -20,6 +20,16 @@ from tensorflow.python.training.tracking.tracking import AutoTrackable from tensorflow.python.saved_model.function_deserialization import RestoredFunction +# NOTE: FunctionSpec moved from eager.function to eager.function_spec +# and then to eager.polymorphic_function +try: + from tensorflow.python.eager.function import FunctionSpec +except ImportError: + try: + from tensorflow.python.eager.function_spec import FunctionSpec + except ImportError: + from tensorflow.python.eager.polymorphic_function import FunctionSpec + try: from tensorflow.python.types.core import GenericFunction from tensorflow.python.types.core import ConcreteFunction @@ -84,6 +94,7 @@ def __repr__(self) -> str: KerasModel = t.Union[Model, Sequential] __all__ = [ + "EagerTensor", "CastableTensorType", "TensorLike", "InputSignature", diff --git a/src/bentoml/_internal/frameworks/tensorflow_v2.py b/src/bentoml/_internal/frameworks/tensorflow_v2.py index 12c916f3db6..3f603c7fe5d 100644 --- a/src/bentoml/_internal/frameworks/tensorflow_v2.py +++ b/src/bentoml/_internal/frameworks/tensorflow_v2.py @@ -25,6 +25,7 @@ from ..runner.container import DataContainerRegistry from ..utils.tensorflow import get_tf_version from ..utils.tensorflow import get_input_signatures_v2 +from ..utils.tensorflow import get_output_signatures_v2 from ..utils.tensorflow import get_restorable_functions from ..utils.tensorflow import cast_py_args_to_tf_function_args @@ -34,6 +35,9 @@ from ..external_typing import tensorflow as tf_ext TFArgType = t.Union[t.List[t.Union[int, float]], ext.NpNDArray, tf_ext.Tensor] + TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor] + TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray] + 
try: import tensorflow as tf @@ -67,7 +71,7 @@ def get(tag_like: str | Tag) -> bentoml.Model: def load_model( bento_model: str | Tag | bentoml.Model, device_name: str = "/device:CPU:0", -) -> "tf_ext.AutoTrackable" | "tf_ext.Module": +) -> tf_ext.AutoTrackable | tf_ext.Module: """ Load a tensorflow model from BentoML local modelstore with given name. @@ -275,11 +279,44 @@ def _gen_run_method(runnable_self: TensorflowRunnable, method_name: str): raw_method = getattr(runnable_self.model, method_name) method_partial_kwargs = partial_kwargs.get(method_name) + output_sigs = get_output_signatures_v2(raw_method) + + if len(output_sigs) == 1: + + # if there's only one output signatures, then we can + # define the _postprocess function without doing + # conditional casting each time + + sig = output_sigs[0] + if isinstance(sig, tuple): + + def _postprocess( + res: tuple[tf_ext.EagerTensor], + ) -> TFRunnableOutputType: + return tuple(t.cast("ext.NpNDArray", r.numpy()) for r in res) + + else: + + def _postprocess(res: tf_ext.EagerTensor) -> TFRunnableOutputType: + return t.cast("ext.NpNDArray", res.numpy()) + + else: + + # if there are no output signature or more than one output + # signatures, the post process function need to do casting + # depends on the real output value each time + + def _postprocess(res: TFModelOutputType) -> TFRunnableOutputType: + if isinstance(res, tuple): + return tuple(t.cast("ext.NpNDArray", r.numpy()) for r in res) + else: + return t.cast("ext.NpNDArray", res.numpy()) + def _run_method( _runnable_self: TensorflowRunnable, - *args: "TFArgType", - **kwargs: "TFArgType", - ) -> "ext.NpNDArray": + *args: TFArgType, + **kwargs: TFArgType, + ) -> TFRunnableOutputType: if method_partial_kwargs is not None: kwargs = dict(method_partial_kwargs, **kwargs) @@ -301,14 +338,14 @@ def _run_method( raise res = raw_method(*casted_args) - return t.cast("ext.NpNDArray", res.numpy()) + return _postprocess(res) return _run_method def add_run_method(method_name: 
str, options: ModelSignature): def run_method( - runnable_self: TensorflowRunnable, *args: "TFArgType", **kwargs: "TFArgType" - ) -> "ext.NpNDArray": + runnable_self: TensorflowRunnable, *args: TFArgType, **kwargs: TFArgType + ) -> TFRunnableOutputType: _run_method = runnable_self.methods_cache.get( method_name ) # is methods_cache nessesary? @@ -338,9 +375,9 @@ class TensorflowTensorContainer( ): @classmethod def batches_to_batch( - cls, batches: t.Sequence["tf_ext.EagerTensor"], batch_dim: int = 0 - ) -> t.Tuple["tf_ext.EagerTensor", list[int]]: - batch: "tf_ext.EagerTensor" = tf.concat(batches, axis=batch_dim) + cls, batches: t.Sequence[tf_ext.EagerTensor], batch_dim: int = 0 + ) -> t.Tuple[tf_ext.EagerTensor, list[int]]: + batch: tf_ext.EagerTensor = tf.concat(batches, axis=batch_dim) # TODO: fix typing mismatch @larme indices: list[int] = list( itertools.accumulate(subbatch.shape[batch_dim] for subbatch in batches) @@ -350,15 +387,15 @@ def batches_to_batch( @classmethod def batch_to_batches( - cls, batch: "tf_ext.EagerTensor", indices: t.Sequence[int], batch_dim: int = 0 - ) -> t.List["tf_ext.EagerTensor"]: + cls, batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0 + ) -> t.List[tf_ext.EagerTensor]: size_splits = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)] return tf.split(batch, size_splits, axis=batch_dim) # type: ignore @classmethod def to_payload( cls, - batch: "tf_ext.EagerTensor", + batch: tf_ext.EagerTensor, batch_dim: int = 0, ) -> Payload: @@ -371,14 +408,14 @@ def to_payload( def from_payload( cls, payload: Payload, - ) -> "tf_ext.EagerTensor": + ) -> tf_ext.EagerTensor: return pickle.loads(payload.data) @classmethod def batch_to_payloads( cls, - batch: "tf_ext.EagerTensor", + batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0, ) -> t.List[Payload]: @@ -393,7 +430,7 @@ def from_batch_payloads( cls, payloads: t.Sequence[Payload], batch_dim: int = 0, - ) -> t.Tuple["tf_ext.EagerTensor", 
t.List[int]]: + ) -> t.Tuple[tf_ext.EagerTensor, t.List[int]]: batches = [cls.from_payload(payload) for payload in payloads] return cls.batches_to_batch(batches, batch_dim) diff --git a/src/bentoml/_internal/utils/tensorflow.py b/src/bentoml/_internal/utils/tensorflow.py index 0d35a3e4bf9..df5188171d2 100644 --- a/src/bentoml/_internal/utils/tensorflow.py +++ b/src/bentoml/_internal/utils/tensorflow.py @@ -213,6 +213,23 @@ def get_input_signatures_v2( return [] +def get_output_signatures_v2( + func: tf_ext.RestoredFunction, +) -> list[tuple[tf_ext.TensorSpec, ...] | tf_ext.TensorSpec]: + if hasattr(func, "concrete_functions") and func.concrete_functions: + return [ + s + for conc in func.concrete_functions + for s in get_output_signatures_v2(conc) + ] + + if hasattr(func, "structured_outputs"): + # for concrete_functions + return [func.structured_outputs] + + return [] + + def get_input_signatures( func: tf_ext.DecoratedFunction, ) -> list[tuple[tf_ext.InputSignature, ...]]: diff --git a/tests/integration/frameworks/models/tensorflow.py b/tests/integration/frameworks/models/tensorflow.py index 06c44aa219b..dc729756656 100644 --- a/tests/integration/frameworks/models/tensorflow.py +++ b/tests/integration/frameworks/models/tensorflow.py @@ -56,6 +56,30 @@ def __call__(self, x1: tf.Tensor, x2: tf.Tensor, factor: tf.Tensor): return self.dense(x1 + x2 * factor) +class MultiOutputModel(tf.Module): + def __init__(self): + super().__init__() + self.v = tf.Variable(2.0) + + @tf.function(input_signature=[tf.TensorSpec([1, 5], tf.float32)]) + def __call__(self, x: tf.Tensor): + return (x * self.v, x) + + +# This model could have 2 output signatures depends on the input +class MultiOutputModel2(tf.Module): + def __init__(self): + super().__init__() + self.v = tf.Variable(2.0) + + @tf.function + def __call__(self, x): + if x.shape[0] > 2: + return (x * self.v, x) + else: + return x + + def make_keras_sequential_model() -> tf.keras.models.Model: net = keras.models.Sequential( ( 
@@ -157,6 +181,76 @@ def make_keras_functional_model() -> tf.keras.Model:
     ],
 )
 
+native_multi_output_model = FrameworkTestModel(
+    name="tf2",
+    model=MultiOutputModel(),
+    configurations=[
+        Config(
+            test_inputs={
+                "__call__": [
+                    Input(
+                        input_args=[i],
+                        expected=lambda out: np.isclose(out[0], input_array * 2).all(),
+                    )
+                    for i in [
+                        input_tensor,
+                        input_tensor_f32,
+                        input_array,
+                        input_array_i32,
+                        input_data,
+                    ]
+                ],
+            },
+        ),
+    ],
+)
+
+input_array2 = np.arange(15, dtype=np.float32).reshape((3, 5))
+input_array2_i32 = np.array(input_array2, dtype="int64")
+input_tensor2 = tf.constant(input_array2, dtype=tf.float64)
+input_tensor2_f32 = tf.constant(input_array2, dtype=tf.float32)
+
+multi_output_model2 = MultiOutputModel2()
+# feed some data for tracing
+_ = multi_output_model2(np.array(input_array, dtype=np.float32))
+_ = multi_output_model2(input_array2)
+
+native_multi_output_model2 = FrameworkTestModel(
+    name="tf2",
+    model=multi_output_model2,
+    configurations=[
+        Config(
+            test_inputs={
+                "__call__": [
+                    Input(
+                        input_args=[i],
+                        expected=lambda out, i=i: np.isclose(out, i).all(),
+                    )
+                    for i in [
+                        input_tensor,
+                        input_tensor_f32,
+                        input_array,
+                        input_array_i32,
+                        input_data,
+                    ]
+                ]
+                + [
+                    Input(
+                        input_args=[i],
+                        expected=lambda out: np.isclose(out[0], input_array2 * 2).all(),
+                    )
+                    for i in [
+                        input_tensor2,
+                        input_tensor2_f32,
+                        input_array2,
+                        input_array2_i32,
+                    ]
+                ],
+            },
+        ),
+    ],
+)
+
 keras_models = [
     FrameworkTestModel(
         name="tf2",
@@ -188,7 +282,10 @@ def make_keras_functional_model() -> tf.keras.Model:
         make_keras_sequential_model(),
     ]
 ]
+
 models: list[FrameworkTestModel] = keras_models + [
     native_model,
     native_multi_input_model,
+    native_multi_output_model,
+    native_multi_output_model2,
 ]