Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
larme committed Oct 19, 2022
1 parent d7b9a59 commit e33775d
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions tests/integration/frameworks/models/tensorflow_v2.py
Expand Up @@ -54,6 +54,30 @@ def __call__(self, x1: tf.Tensor, x2: tf.Tensor, factor: tf.Tensor):
return self.dense(x1 + x2 * factor)


class MultiOutputModel(tf.Module):
def __init__(self):
super().__init__()
self.v = tf.Variable(2.0)

@tf.function(input_signature=[tf.TensorSpec([1, 5], tf.float32)])
def __call__(self, x: tf.Tensor):
return (x * self.v, x)


# This model could have 2 output signatures depends on the input
class MultiOutputModel2(tf.Module):
def __init__(self):
super().__init__()
self.v = tf.Variable(2.0)

@tf.function
def __call__(self, x):
if x.shape[0] > 2:
return (x * self.v, x)
else:
return x


def make_keras_sequential_model() -> tf.keras.models.Model:
net = keras.models.Sequential(
(
Expand Down Expand Up @@ -155,6 +179,76 @@ def make_keras_functional_model() -> tf.keras.Model:
],
)

native_multi_output_model = FrameworkTestModel(
name="tf2",
model=MultiOutputModel(),
configurations=[
Config(
test_inputs={
"__call__": [
Input(
input_args=[i],
expected=lambda out: np.isclose(out[0], input_array * 2).all(),
)
for i in [
input_tensor,
input_tensor_f32,
input_array,
input_array_i32,
input_data,
]
],
},
),
],
)

input_array2 = np.arange(15, dtype=np.float32).reshape((3, 5))
input_array2_i32 = np.array(input_array2, dtype="int64")
input_tensor2 = tf.constant(input_array2, dtype=tf.float64)
input_tensor2_f32 = tf.constant(input_array2, dtype=tf.float32)

multi_output_model2 = MultiOutputModel2()
# feed some data for tracing
_ = multi_output_model2(np.array(input_array, dtype=np.float32))
_ = multi_output_model2(input_array2)

native_multi_output_model2 = FrameworkTestModel(
name="tf2",
model=multi_output_model2,
configurations=[
Config(
test_inputs={
"__call__": [
Input(
input_args=[i],
expected=lambda out: np.isclose(out, i).all(),
)
for i in [
input_tensor,
input_tensor_f32,
input_array,
input_array_i32,
input_data,
]
]
+ [
Input(
input_args=[i],
expected=lambda out: np.isclose(out[0], input_array2 * 2).all(),
)
for i in [
input_tensor2,
input_tensor2_f32,
input_array2,
input_array2_i32,
]
],
},
),
],
)

keras_models = [
FrameworkTestModel(
name="tf2",
Expand Down Expand Up @@ -186,7 +280,10 @@ def make_keras_functional_model() -> tf.keras.Model:
make_keras_sequential_model(),
]
]

models: list[FrameworkTestModel] = keras_models + [
native_model,
native_multi_input_model,
native_multi_output_model,
native_multi_output_model2,
]

0 comments on commit e33775d

Please sign in to comment.