feat: tensorflow multi outputs support (#3115)
larme committed Nov 5, 2022
1 parent e1572cb commit 566220e
Showing 4 changed files with 178 additions and 16 deletions.
11 changes: 11 additions & 0 deletions src/bentoml/_internal/external_typing/tensorflow.py
@@ -20,6 +20,16 @@
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.saved_model.function_deserialization import RestoredFunction

# NOTE: FunctionSpec moved from eager.function to eager.function_spec
# and then to eager.polymorphic_function
try:
from tensorflow.python.eager.function import FunctionSpec
except ImportError:
try:
from tensorflow.python.eager.function_spec import FunctionSpec
except ImportError:
from tensorflow.python.eager.polymorphic_function import FunctionSpec

try:
from tensorflow.python.types.core import GenericFunction
from tensorflow.python.types.core import ConcreteFunction
@@ -84,6 +94,7 @@ def __repr__(self) -> str:
KerasModel = t.Union[Model, Sequential]

__all__ = [
"EagerTensor",
"CastableTensorType",
"TensorLike",
"InputSignature",
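Note: the cascading try/except above is a version-compatibility shim; TensorFlow has relocated FunctionSpec across releases, so each known module path is tried in turn. A minimal standalone sketch of the same pattern, assuming only the three module paths shown in the diff (the helper name _locate_function_spec is illustrative, not part of this change):

    def _locate_function_spec():
        # Try each known home of FunctionSpec in turn.
        candidates = (
            "tensorflow.python.eager.function",
            "tensorflow.python.eager.function_spec",
            "tensorflow.python.eager.polymorphic_function",
        )
        for path in candidates:
            try:
                module = __import__(path, fromlist=["FunctionSpec"])
                return module.FunctionSpec
            except (ImportError, AttributeError):
                continue
        raise ImportError("FunctionSpec not found in this TensorFlow build")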
69 changes: 53 additions & 16 deletions src/bentoml/_internal/frameworks/tensorflow_v2.py
@@ -25,6 +25,7 @@
from ..runner.container import DataContainerRegistry
from ..utils.tensorflow import get_tf_version
from ..utils.tensorflow import get_input_signatures_v2
from ..utils.tensorflow import get_output_signatures_v2
from ..utils.tensorflow import get_restorable_functions
from ..utils.tensorflow import cast_py_args_to_tf_function_args

@@ -34,6 +35,9 @@
from ..external_typing import tensorflow as tf_ext

TFArgType = t.Union[t.List[t.Union[int, float]], ext.NpNDArray, tf_ext.Tensor]
    TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor, ...]
    TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray, ...]


try:
import tensorflow as tf
@@ -67,7 +71,7 @@ def get(tag_like: str | Tag) -> bentoml.Model:
def load_model(
bento_model: str | Tag | bentoml.Model,
device_name: str = "/device:CPU:0",
) -> "tf_ext.AutoTrackable" | "tf_ext.Module":
) -> tf_ext.AutoTrackable | tf_ext.Module:
"""
    Load a TensorFlow model from the BentoML local model store with the given name.
@@ -275,11 +279,44 @@ def _gen_run_method(runnable_self: TensorflowRunnable, method_name: str):
raw_method = getattr(runnable_self.model, method_name)
method_partial_kwargs = partial_kwargs.get(method_name)

output_sigs = get_output_signatures_v2(raw_method)

if len(output_sigs) == 1:

            # If there is only one output signature, we can define
            # _postprocess up front, without conditional casting on
            # each call.

sig = output_sigs[0]
if isinstance(sig, tuple):

def _postprocess(
                res: tuple[tf_ext.EagerTensor, ...],
) -> TFRunnableOutputType:
return tuple(t.cast("ext.NpNDArray", r.numpy()) for r in res)

else:

def _postprocess(res: tf_ext.EagerTensor) -> TFRunnableOutputType:
return t.cast("ext.NpNDArray", res.numpy())

else:

            # If there are no output signatures, or more than one, the
            # _postprocess function needs to cast conditionally based on
            # the actual output value each time.

def _postprocess(res: TFModelOutputType) -> TFRunnableOutputType:
if isinstance(res, tuple):
return tuple(t.cast("ext.NpNDArray", r.numpy()) for r in res)
else:
return t.cast("ext.NpNDArray", res.numpy())

def _run_method(
_runnable_self: TensorflowRunnable,
*args: "TFArgType",
**kwargs: "TFArgType",
) -> "ext.NpNDArray":
*args: TFArgType,
**kwargs: TFArgType,
) -> TFRunnableOutputType:
if method_partial_kwargs is not None:
kwargs = dict(method_partial_kwargs, **kwargs)

@@ -301,14 +338,14 @@ def _run_method(
raise

res = raw_method(*casted_args)
return t.cast("ext.NpNDArray", res.numpy())
return _postprocess(res)

return _run_method

def add_run_method(method_name: str, options: ModelSignature):
def run_method(
runnable_self: TensorflowRunnable, *args: "TFArgType", **kwargs: "TFArgType"
) -> "ext.NpNDArray":
runnable_self: TensorflowRunnable, *args: TFArgType, **kwargs: TFArgType
) -> TFRunnableOutputType:
_run_method = runnable_self.methods_cache.get(
method_name
            )  # is methods_cache necessary?
@@ -338,9 +375,9 @@ class TensorflowTensorContainer(
):
@classmethod
def batches_to_batch(
cls, batches: t.Sequence["tf_ext.EagerTensor"], batch_dim: int = 0
) -> t.Tuple["tf_ext.EagerTensor", list[int]]:
batch: "tf_ext.EagerTensor" = tf.concat(batches, axis=batch_dim)
cls, batches: t.Sequence[tf_ext.EagerTensor], batch_dim: int = 0
) -> t.Tuple[tf_ext.EagerTensor, list[int]]:
batch: tf_ext.EagerTensor = tf.concat(batches, axis=batch_dim)
# TODO: fix typing mismatch @larme
indices: list[int] = list(
itertools.accumulate(subbatch.shape[batch_dim] for subbatch in batches)
@@ -350,15 +387,15 @@ def batches_to_batch(

@classmethod
def batch_to_batches(
cls, batch: "tf_ext.EagerTensor", indices: t.Sequence[int], batch_dim: int = 0
) -> t.List["tf_ext.EagerTensor"]:
cls, batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0
) -> t.List[tf_ext.EagerTensor]:
size_splits = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)]
return tf.split(batch, size_splits, axis=batch_dim) # type: ignore

@classmethod
def to_payload(
cls,
batch: "tf_ext.EagerTensor",
batch: tf_ext.EagerTensor,
batch_dim: int = 0,
) -> Payload:

@@ -371,14 +408,14 @@ def to_payload(
def from_payload(
cls,
payload: Payload,
) -> "tf_ext.EagerTensor":
) -> tf_ext.EagerTensor:

return pickle.loads(payload.data)

@classmethod
def batch_to_payloads(
cls,
batch: "tf_ext.EagerTensor",
batch: tf_ext.EagerTensor,
indices: t.Sequence[int],
batch_dim: int = 0,
) -> t.List[Payload]:
@@ -393,7 +430,7 @@ def from_batch_payloads(
cls,
payloads: t.Sequence[Payload],
batch_dim: int = 0,
) -> t.Tuple["tf_ext.EagerTensor", t.List[int]]:
) -> t.Tuple[tf_ext.EagerTensor, t.List[int]]:
batches = [cls.from_payload(payload) for payload in payloads]
return cls.batches_to_batch(batches, batch_dim)

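The practical effect of the _postprocess specialization above: a runner method now returns a single NumPy array or a tuple of arrays, mirroring the model's output structure, instead of always assuming a single tensor. A hedged illustration of the conversion logic (the tensors here are made-up stand-ins for a model result):

    import numpy as np
    import tensorflow as tf

    res = (tf.constant([[2.0]]), tf.constant([[1.0]]))  # e.g. a two-output model result
    if isinstance(res, tuple):
        out = tuple(r.numpy() for r in res)  # tuple of np.ndarray
    else:
        out = res.numpy()  # single np.ndarray
    assert all(isinstance(r, np.ndarray) for r in out)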
17 changes: 17 additions & 0 deletions src/bentoml/_internal/utils/tensorflow.py
@@ -213,6 +213,23 @@ def get_input_signatures_v2(
return []


def get_output_signatures_v2(
func: tf_ext.RestoredFunction,
) -> list[tuple[tf_ext.TensorSpec, ...] | tf_ext.TensorSpec]:
if hasattr(func, "concrete_functions") and func.concrete_functions:
return [
s
for conc in func.concrete_functions
for s in get_output_signatures_v2(conc)
]

if hasattr(func, "structured_outputs"):
        # a concrete function exposes its output layout via structured_outputs
return [func.structured_outputs]

return []


def get_input_signatures(
func: tf_ext.DecoratedFunction,
) -> list[tuple[tf_ext.InputSignature, ...]]:
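For context: a traced tf.function records its output layout on each concrete function as structured_outputs, which the recursion above flattens into a list. A small sketch of what the helper would encounter for a two-output function (TF 2.x tracing semantics assumed; this snippet is not part of the diff):

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([1, 5], tf.float32)])
    def double_and_echo(x):
        return (x * 2.0, x)

    # With an input_signature, the concrete function can be built without arguments.
    conc = double_and_echo.get_concrete_function()
    print(conc.structured_outputs)  # a 2-tuple describing both outputs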
97 changes: 97 additions & 0 deletions tests/integration/frameworks/models/tensorflow.py
@@ -56,6 +56,30 @@ def __call__(self, x1: tf.Tensor, x2: tf.Tensor, factor: tf.Tensor):
return self.dense(x1 + x2 * factor)


class MultiOutputModel(tf.Module):
def __init__(self):
super().__init__()
self.v = tf.Variable(2.0)

@tf.function(input_signature=[tf.TensorSpec([1, 5], tf.float32)])
def __call__(self, x: tf.Tensor):
return (x * self.v, x)


# This model can have two output signatures, depending on the input
class MultiOutputModel2(tf.Module):
def __init__(self):
super().__init__()
self.v = tf.Variable(2.0)

@tf.function
def __call__(self, x):
if x.shape[0] > 2:
return (x * self.v, x)
else:
return x


def make_keras_sequential_model() -> tf.keras.models.Model:
net = keras.models.Sequential(
(
@@ -157,6 +181,76 @@ def make_keras_functional_model() -> tf.keras.Model:
],
)

native_multi_output_model = FrameworkTestModel(
name="tf2",
model=MultiOutputModel(),
configurations=[
Config(
test_inputs={
"__call__": [
Input(
input_args=[i],
expected=lambda out: np.isclose(out[0], input_array * 2).all(),
)
for i in [
input_tensor,
input_tensor_f32,
input_array,
input_array_i32,
input_data,
]
],
},
),
],
)

input_array2 = np.arange(15, dtype=np.float32).reshape((3, 5))
input_array2_i32 = np.array(input_array2, dtype="int64")
input_tensor2 = tf.constant(input_array2, dtype=tf.float64)
input_tensor2_f32 = tf.constant(input_array2, dtype=tf.float32)

multi_output_model2 = MultiOutputModel2()
# feed some data for tracing
_ = multi_output_model2(np.array(input_array, dtype=np.float32))
_ = multi_output_model2(input_array2)

native_multi_output_model2 = FrameworkTestModel(
name="tf2",
model=multi_output_model2,
configurations=[
Config(
test_inputs={
"__call__": [
Input(
input_args=[i],
                        expected=lambda out, i=i: np.isclose(out, i).all(),  # bind i at definition time
)
for i in [
input_tensor,
input_tensor_f32,
input_array,
input_array_i32,
input_data,
]
]
+ [
Input(
input_args=[i],
expected=lambda out: np.isclose(out[0], input_array2 * 2).all(),
)
for i in [
input_tensor2,
input_tensor2_f32,
input_array2,
input_array2_i32,
]
],
},
),
],
)

keras_models = [
FrameworkTestModel(
name="tf2",
@@ -188,7 +282,10 @@ def make_keras_functional_model() -> tf.keras.Model:
make_keras_sequential_model(),
]
]

models: list[FrameworkTestModel] = keras_models + [
native_model,
native_multi_input_model,
native_multi_output_model,
native_multi_output_model2,
]
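
Taken together, the new test models exercise the multi-output path end to end. A hedged sketch of the user-facing behavior this enables (the model tag is illustrative; API calls as in BentoML 1.x):

    import bentoml
    import numpy as np

    bentoml.tensorflow.save_model("tf2_multi_output", MultiOutputModel())
    runner = bentoml.tensorflow.get("tf2_multi_output:latest").to_runner()
    runner.init_local()  # in-process initialization, for local testing only

    # MultiOutputModel returns (x * 2, x), so the runner now yields a tuple.
    scaled, echoed = runner.run(np.arange(5, dtype=np.float32).reshape((1, 5)))
    assert np.isclose(scaled, echoed * 2).all()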
