-
Notifications
You must be signed in to change notification settings - Fork 33
/
interpreter_test.py
66 lines (50 loc) · 2.33 KB
/
interpreter_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import sys
import numpy as np
import pytest
import tensorflow as tf
from larq_compute_engine.tflite.python.interpreter import Interpreter
@pytest.mark.parametrize("use_iterator", [True, False])
def test_interpreter(use_iterator):
    """Check single-input inference through the LCE ``Interpreter``.

    A Keras model consisting of a single ``Flatten`` layer is converted to
    TFLite and wrapped in ``Interpreter``. The test verifies the reported
    input/output dtypes and shapes, then runs ``predict`` on random data and
    expects each output row to equal the corresponding flattened input.

    Args:
        use_iterator: When true, feed the samples through a generator that
            yields one sample at a time; otherwise pass the whole array.
    """
    input_shape = (24, 24, 3)
    num_samples = 16

    x = tf.keras.Input(input_shape)
    model = tf.keras.Model(x, tf.keras.layers.Flatten()(x))
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    inputs = np.random.uniform(-1, 1, size=(num_samples, *input_shape)).astype(
        np.float32
    )
    expected_outputs = inputs.reshape(num_samples, -1)

    interpreter = Interpreter(converter.convert())
    assert interpreter.input_types == [np.float32]
    assert interpreter.output_types == [np.float32]
    assert interpreter.input_shapes == [(1, *input_shape)]
    # `np.product` was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # `np.prod` is the supported spelling on every NumPy version.
    assert interpreter.output_shapes == [(1, np.prod(input_shape))]

    def input_fn():
        if use_iterator:
            # Yield samples one by one without shadowing the `input` builtin.
            return (sample for sample in inputs)
        return inputs

    # NOTE(review): the second positional argument is presumably the verbosity
    # level of `Interpreter.predict` -- confirm against its signature.
    outputs = interpreter.predict(input_fn(), 1)
    np.testing.assert_allclose(outputs, expected_outputs)
@pytest.mark.parametrize("use_iterator", [True, False])
def test_interpreter_multi_input(use_iterator):
    """Check two-input / two-output inference through the LCE ``Interpreter``.

    The model flattens each of its two inputs independently, so every output
    is expected to equal its matching input reshaped to ``(16, -1)``.

    Args:
        use_iterator: When true, feed the data as a generator of per-sample
            ``[x, y]`` pairs; otherwise pass the list of the two full arrays.
    """
    first_in = tf.keras.Input((24, 24, 2))
    second_in = tf.keras.Input((24, 24, 1))
    flattened = [
        tf.keras.layers.Flatten()(first_in),
        tf.keras.layers.Flatten()(second_in),
    ]
    keras_model = tf.keras.Model([first_in, second_in], flattened)
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # Random float32 data for both inputs, 16 samples each.
    x_np = np.random.uniform(-1, 1, size=(16, 24, 24, 2)).astype(np.float32)
    y_np = np.random.uniform(-1, 1, size=(16, 24, 24, 1)).astype(np.float32)
    want_x = x_np.reshape(16, -1)
    want_y = y_np.reshape(16, -1)

    interpreter = Interpreter(tflite_converter.convert(), num_threads=2)

    # The interpreter should report per-sample (batch size 1) signatures.
    assert interpreter.input_types == [np.float32, np.float32]
    assert interpreter.output_types == [np.float32, np.float32]
    assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)]
    assert interpreter.output_shapes == [(1, 24 * 24 * 2), (1, 24 * 24 * 1)]

    # Either stream per-sample pairs lazily or hand over both arrays at once.
    model_inputs = (
        ([sample_x, sample_y] for sample_x, sample_y in zip(x_np, y_np))
        if use_iterator
        else [x_np, y_np]
    )

    got_x, got_y = interpreter.predict(model_inputs)
    np.testing.assert_allclose(got_x, want_x)
    np.testing.assert_allclose(got_y, want_y)
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest (with output
    # capturing disabled via "-s") and propagate pytest's exit status.
    exit_status = pytest.main([__file__, "-s"])
    sys.exit(exit_status)