Skip to content

Commit

Permalink
Add support for iterators in Interpreter.predict (#511)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Hillier <7688302+AdamHillier@users.noreply.github.com>
  • Loading branch information
lgeiger and AdamHillier committed Sep 21, 2020
1 parent 35e0670 commit ca5f8dd
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 22 deletions.
47 changes: 29 additions & 18 deletions larq_compute_engine/tflite/python/interpreter.py
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import Iterator, List, Tuple, Union

import numpy as np
from tqdm import tqdm
Expand All @@ -7,6 +7,28 @@

__all__ = ["Interpreter"]

Data = Union[np.ndarray, List[np.ndarray]]


def data_generator(x: Union[Data, Iterator[Data]]) -> Iterator[List[np.ndarray]]:
if isinstance(x, np.ndarray):
for inputs in x:
yield [np.expand_dims(inputs, axis=0)]
elif isinstance(x, list):
for inputs in zip(*x):
yield [np.expand_dims(inp, axis=0) for inp in inputs]
elif hasattr(x, "__next__") and hasattr(x, "__iter__"):
for inputs in x:
if isinstance(inputs, np.ndarray):
yield [np.expand_dims(inputs, axis=0)]
else:
yield [np.expand_dims(inp, axis=0) for inp in inputs]
else:
raise ValueError(
"Expected either a list of inputs or a Numpy array with implicit initial "
f"batch dimension or an iterator yielding one of the above. Received: {x}"
)


class Interpreter:
"""Interpreter interface for Larq Compute Engine Models.
Expand Down Expand Up @@ -51,9 +73,7 @@ def output_shapes(self) -> List[Tuple[int]]:
"""Returns a list of output shapes."""
return self.interpreter.output_shapes

def predict(
self, x: Union[np.ndarray, List[np.ndarray]], verbose: int = 0
) -> Union[np.ndarray, List[np.ndarray]]:
def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data:
"""Generates output predictions for the input samples.
# Arguments
Expand All @@ -65,21 +85,12 @@ def predict(
Numpy array(s) of output predictions.
"""

if not isinstance(x, (list, np.ndarray)) or len(x) == 0:
raise ValueError(
"Expected either a non-empty list of inputs or a Numpy array with "
f"implicit initial batch dimension. Received: {x}"
)
data_iterator = data_generator(x)
if verbose >= 1:
data_iterator = tqdm(data_iterator)

if len(self.input_shapes) == 1:
x = [x]

batch_iter = tqdm(zip(*x)) if verbose >= 1 else zip(*x)
prediction_batch_iter = (
self.interpreter.predict([np.expand_dims(inp, axis=0) for inp in inputs])
for inputs in batch_iter
)
outputs = [np.concatenate(batches) for batches in zip(*prediction_batch_iter)]
prediction_iter = (self.interpreter.predict(inputs) for inputs in data_iterator)
outputs = [np.concatenate(batches) for batches in zip(*prediction_iter)]

if len(self.output_shapes) == 1:
return outputs[0]
Expand Down
20 changes: 16 additions & 4 deletions larq_compute_engine/tflite/tests/interpreter_test.py
Expand Up @@ -7,7 +7,8 @@
from larq_compute_engine.tflite.python.interpreter import Interpreter


def test_interpreter():
@pytest.mark.parametrize("use_iterator", [True, False])
def test_interpreter(use_iterator):
input_shape = (24, 24, 3)
x = tf.keras.Input(input_shape)
model = tf.keras.Model(x, tf.keras.layers.Flatten()(x))
Expand All @@ -22,11 +23,17 @@ def test_interpreter():
assert interpreter.input_shapes == [(1, *input_shape)]
assert interpreter.output_shapes == [(1, np.product(input_shape))]

outputs = interpreter.predict(inputs, 1)
def input_fn():
if use_iterator:
return (input for input in inputs)
return inputs

outputs = interpreter.predict(input_fn(), 1)
np.testing.assert_allclose(outputs, expected_outputs)


def test_interpreter_multi_input():
@pytest.mark.parametrize("use_iterator", [True, False])
def test_interpreter_multi_input(use_iterator):
x = tf.keras.Input((24, 24, 2))
y = tf.keras.Input((24, 24, 1))
model = tf.keras.Model(
Expand All @@ -45,7 +52,12 @@ def test_interpreter_multi_input():
assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)]
assert interpreter.output_shapes == [(1, 24 * 24 * 2), (1, 24 * 24 * 1)]

output_x, output_y = interpreter.predict([x_np, y_np])
def input_fn():
if use_iterator:
return ([x, y] for x, y in zip(x_np, y_np))
return [x_np, y_np]

output_x, output_y = interpreter.predict(input_fn())
np.testing.assert_allclose(output_x, expected_output_x)
np.testing.assert_allclose(output_y, expected_output_y)

Expand Down

0 comments on commit ca5f8dd

Please sign in to comment.