Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: tensorflow multi outputs support #3115

Merged
merged 4 commits into from Nov 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/bentoml/_internal/external_typing/tensorflow.py
Expand Up @@ -20,6 +20,16 @@
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.saved_model.function_deserialization import RestoredFunction

# NOTE: FunctionSpec moved from eager.function to eager.function_spec
# and then to eager.polymorphic_function
try:
from tensorflow.python.eager.function import FunctionSpec
except ImportError:
try:
from tensorflow.python.eager.function_spec import FunctionSpec
except ImportError:
from tensorflow.python.eager.polymorphic_function import FunctionSpec

try:
from tensorflow.python.types.core import GenericFunction
from tensorflow.python.types.core import ConcreteFunction
Expand Down Expand Up @@ -84,6 +94,7 @@ def __repr__(self) -> str:
KerasModel = t.Union[Model, Sequential]

__all__ = [
"EagerTensor",
"CastableTensorType",
"TensorLike",
"InputSignature",
Expand Down
69 changes: 53 additions & 16 deletions src/bentoml/_internal/frameworks/tensorflow_v2.py
Expand Up @@ -25,6 +25,7 @@
from ..runner.container import DataContainerRegistry
from ..utils.tensorflow import get_tf_version
from ..utils.tensorflow import get_input_signatures_v2
from ..utils.tensorflow import get_output_signatures_v2
from ..utils.tensorflow import get_restorable_functions
from ..utils.tensorflow import cast_py_args_to_tf_function_args

Expand All @@ -34,6 +35,9 @@
from ..external_typing import tensorflow as tf_ext

TFArgType = t.Union[t.List[t.Union[int, float]], ext.NpNDArray, tf_ext.Tensor]
TFModelOutputType = tf_ext.EagerTensor | tuple[tf_ext.EagerTensor]
TFRunnableOutputType = ext.NpNDArray | tuple[ext.NpNDArray]


try:
import tensorflow as tf
Expand Down Expand Up @@ -67,7 +71,7 @@ def get(tag_like: str | Tag) -> bentoml.Model:
def load_model(
bento_model: str | Tag | bentoml.Model,
device_name: str = "/device:CPU:0",
) -> "tf_ext.AutoTrackable" | "tf_ext.Module":
) -> tf_ext.AutoTrackable | tf_ext.Module:
"""
Load a tensorflow model from BentoML local modelstore with given name.

Expand Down Expand Up @@ -275,11 +279,44 @@ def _gen_run_method(runnable_self: TensorflowRunnable, method_name: str):
raw_method = getattr(runnable_self.model, method_name)
method_partial_kwargs = partial_kwargs.get(method_name)

output_sigs = get_output_signatures_v2(raw_method)

if len(output_sigs) == 1:

# if there's only one output signatures, then we can
# define the _postprocess function without doing
# conditional casting each time

sig = output_sigs[0]
if isinstance(sig, tuple):

def _postprocess(
res: tuple[tf_ext.EagerTensor],
) -> TFRunnableOutputType:
return tuple(t.cast("ext.NpNDArray", r.numpy()) for r in res)
aarnphm marked this conversation as resolved.
Show resolved Hide resolved

else:

def _postprocess(res: tf_ext.EagerTensor) -> TFRunnableOutputType:
return t.cast("ext.NpNDArray", res.numpy())

else:

# if there are no output signature or more than one output
# signatures, the post process function need to do casting
# depends on the real output value each time

def _postprocess(res: TFModelOutputType) -> TFRunnableOutputType:
if isinstance(res, tuple):
return tuple(t.cast("ext.NpNDArray", r.numpy()) for r in res)
else:
return t.cast("ext.NpNDArray", res.numpy())

def _run_method(
_runnable_self: TensorflowRunnable,
*args: "TFArgType",
**kwargs: "TFArgType",
) -> "ext.NpNDArray":
*args: TFArgType,
**kwargs: TFArgType,
) -> TFRunnableOutputType:
if method_partial_kwargs is not None:
kwargs = dict(method_partial_kwargs, **kwargs)

Expand All @@ -301,14 +338,14 @@ def _run_method(
raise

res = raw_method(*casted_args)
return t.cast("ext.NpNDArray", res.numpy())
return _postprocess(res)

return _run_method

def add_run_method(method_name: str, options: ModelSignature):
def run_method(
runnable_self: TensorflowRunnable, *args: "TFArgType", **kwargs: "TFArgType"
) -> "ext.NpNDArray":
runnable_self: TensorflowRunnable, *args: TFArgType, **kwargs: TFArgType
) -> TFRunnableOutputType:
_run_method = runnable_self.methods_cache.get(
method_name
) # is methods_cache nessesary?
Expand Down Expand Up @@ -338,9 +375,9 @@ class TensorflowTensorContainer(
):
@classmethod
def batches_to_batch(
cls, batches: t.Sequence["tf_ext.EagerTensor"], batch_dim: int = 0
) -> t.Tuple["tf_ext.EagerTensor", list[int]]:
batch: "tf_ext.EagerTensor" = tf.concat(batches, axis=batch_dim)
cls, batches: t.Sequence[tf_ext.EagerTensor], batch_dim: int = 0
) -> t.Tuple[tf_ext.EagerTensor, list[int]]:
batch: tf_ext.EagerTensor = tf.concat(batches, axis=batch_dim)
# TODO: fix typing mismatch @larme
indices: list[int] = list(
itertools.accumulate(subbatch.shape[batch_dim] for subbatch in batches)
Expand All @@ -350,15 +387,15 @@ def batches_to_batch(

@classmethod
def batch_to_batches(
    cls, batch: tf_ext.EagerTensor, indices: t.Sequence[int], batch_dim: int = 0
) -> t.List[tf_ext.EagerTensor]:
    """Split ``batch`` back into sub-batches along ``batch_dim``.

    ``indices`` holds cumulative sub-batch boundaries (as produced by
    ``batches_to_batch``); the difference between consecutive entries
    gives each sub-batch size passed to ``tf.split``.
    """
    sizes = [stop - start for start, stop in zip(indices, indices[1:])]
    return tf.split(batch, sizes, axis=batch_dim)  # type: ignore

@classmethod
def to_payload(
cls,
batch: "tf_ext.EagerTensor",
batch: tf_ext.EagerTensor,
batch_dim: int = 0,
) -> Payload:

Expand All @@ -371,14 +408,14 @@ def to_payload(
def from_payload(
    cls,
    payload: Payload,
) -> tf_ext.EagerTensor:
    """Reconstruct a tensor from a ``Payload`` by unpickling its bytes.

    Presumably the inverse of ``to_payload`` — verify against that method.
    NOTE(review): ``pickle.loads`` is unsafe on untrusted data; this must
    only receive payloads produced by this container.
    """

    return pickle.loads(payload.data)

@classmethod
def batch_to_payloads(
cls,
batch: "tf_ext.EagerTensor",
batch: tf_ext.EagerTensor,
indices: t.Sequence[int],
batch_dim: int = 0,
) -> t.List[Payload]:
Expand All @@ -393,7 +430,7 @@ def from_batch_payloads(
cls,
payloads: t.Sequence[Payload],
batch_dim: int = 0,
) -> t.Tuple["tf_ext.EagerTensor", t.List[int]]:
) -> t.Tuple[tf_ext.EagerTensor, t.List[int]]:
batches = [cls.from_payload(payload) for payload in payloads]
return cls.batches_to_batch(batches, batch_dim)

Expand Down
17 changes: 17 additions & 0 deletions src/bentoml/_internal/utils/tensorflow.py
Expand Up @@ -213,6 +213,23 @@ def get_input_signatures_v2(
return []


def get_output_signatures_v2(
    func: tf_ext.RestoredFunction,
) -> list[tuple[tf_ext.TensorSpec, ...] | tf_ext.TensorSpec]:
    """Collect the structured output signatures of a restored tf.function.

    A restored function fans out to each of its traced concrete functions
    and the per-trace results are flattened into a single list.  A concrete
    function contributes its own ``structured_outputs``.  Anything exposing
    neither attribute yields an empty list.
    """
    concrete_fns = getattr(func, "concrete_functions", None)
    if concrete_fns:
        collected: list[tuple[tf_ext.TensorSpec, ...] | tf_ext.TensorSpec] = []
        for concrete_fn in concrete_fns:
            collected.extend(get_output_signatures_v2(concrete_fn))
        return collected

    if hasattr(func, "structured_outputs"):
        # concrete functions expose their outputs directly
        return [func.structured_outputs]

    return []


def get_input_signatures(
func: tf_ext.DecoratedFunction,
) -> list[tuple[tf_ext.InputSignature, ...]]:
Expand Down
97 changes: 97 additions & 0 deletions tests/integration/frameworks/models/tensorflow.py
Expand Up @@ -56,6 +56,30 @@ def __call__(self, x1: tf.Tensor, x2: tf.Tensor, factor: tf.Tensor):
return self.dense(x1 + x2 * factor)


class MultiOutputModel(tf.Module):
    """Toy module with a fixed two-output signature: ``(x * v, x)``."""

    def __init__(self):
        super().__init__()
        # scale factor applied to the first output
        self.v = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([1, 5], tf.float32)])
    def __call__(self, x: tf.Tensor):
        scaled = x * self.v
        return (scaled, x)


# This model could have 2 output signatures depending on the input
class MultiOutputModel2(tf.Module):
    """Returns ``(x * v, x)`` when the leading dim exceeds 2, else ``x`` unchanged."""

    def __init__(self):
        super().__init__()
        self.v = tf.Variable(2.0)

    @tf.function
    def __call__(self, x):
        # guard clause: small batches pass through untouched
        if x.shape[0] <= 2:
            return x
        return (x * self.v, x)


def make_keras_sequential_model() -> tf.keras.models.Model:
net = keras.models.Sequential(
(
Expand Down Expand Up @@ -157,6 +181,76 @@ def make_keras_functional_model() -> tf.keras.Model:
],
)

# Fixture for MultiOutputModel, which returns the tuple (x * v, x) with
# v == 2.0 — only the first (scaled) output is checked here.
native_multi_output_model = FrameworkTestModel(
    name="tf2",
    model=MultiOutputModel(),
    configurations=[
        Config(
            test_inputs={
                "__call__": [
                    Input(
                        input_args=[i],
                        # the lambda ignores the loop variable `i` on purpose:
                        # every input variant represents the same values, so a
                        # single expected array works for all of them
                        expected=lambda out: np.isclose(out[0], input_array * 2).all(),
                    )
                    for i in [
                        input_tensor,
                        input_tensor_f32,
                        input_array,
                        input_array_i32,
                        input_data,
                    ]
                ],
            },
        ),
    ],
)

# A (3, 5) batch: shape[0] > 2, so MultiOutputModel2 takes its
# tuple-returning branch for these inputs.
input_array2 = np.arange(15, dtype=np.float32).reshape((3, 5))
# NOTE(review): named *_i32 but created with dtype "int64" — confirm whether
# the dtype or the variable name is the intended one.
input_array2_i32 = np.array(input_array2, dtype="int64")
input_tensor2 = tf.constant(input_array2, dtype=tf.float64)
input_tensor2_f32 = tf.constant(input_array2, dtype=tf.float32)

multi_output_model2 = MultiOutputModel2()
# feed some data for tracing: call once per input shape so both output
# signatures (plain tensor and tuple) get concrete traces
_ = multi_output_model2(np.array(input_array, dtype=np.float32))
_ = multi_output_model2(input_array2)

# Fixture for MultiOutputModel2: small batches are returned unchanged, (3, 5)
# batches take the tuple branch where only the scaled first output is checked.
native_multi_output_model2 = FrameworkTestModel(
    name="tf2",
    model=multi_output_model2,
    configurations=[
        Config(
            test_inputs={
                "__call__": [
                    Input(
                        input_args=[i],
                        # bind `i` as a default argument: a plain closure
                        # late-binds the loop variable, so every lambda would
                        # otherwise compare `out` against the LAST input in
                        # the list instead of its own
                        expected=lambda out, i=i: np.isclose(out, i).all(),
                    )
                    for i in [
                        input_tensor,
                        input_tensor_f32,
                        input_array,
                        input_array_i32,
                        input_data,
                    ]
                ]
                + [
                    Input(
                        input_args=[i],
                        # tuple branch: check the scaled first output (v == 2.0);
                        # `i` is deliberately unused here, all variants carry
                        # the same values as input_array2
                        expected=lambda out: np.isclose(out[0], input_array2 * 2).all(),
                    )
                    for i in [
                        input_tensor2,
                        input_tensor2_f32,
                        input_array2,
                        input_array2_i32,
                    ]
                ],
            },
        ),
    ],
)

keras_models = [
FrameworkTestModel(
name="tf2",
Expand Down Expand Up @@ -188,7 +282,10 @@ def make_keras_functional_model() -> tf.keras.Model:
make_keras_sequential_model(),
]
]

# Full roster of TF2 models exercised by the framework integration tests.
models: list[FrameworkTestModel] = [
    *keras_models,
    native_model,
    native_multi_input_model,
    native_multi_output_model,
    native_multi_output_model2,
]