diff --git a/larq_compute_engine/tflite/python/interpreter.py b/larq_compute_engine/tflite/python/interpreter.py index e3c072114..02754188b 100644 --- a/larq_compute_engine/tflite/python/interpreter.py +++ b/larq_compute_engine/tflite/python/interpreter.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import Iterator, List, Tuple, Union import numpy as np from tqdm import tqdm @@ -7,6 +7,28 @@ __all__ = ["Interpreter"] +Data = Union[np.ndarray, List[np.ndarray]] + + +def data_generator(x: Union[Data, Iterator[Data]]) -> Iterator[List[np.ndarray]]: + if isinstance(x, np.ndarray): + for inputs in x: + yield [np.expand_dims(inputs, axis=0)] + elif isinstance(x, list): + for inputs in zip(*x): + yield [np.expand_dims(inp, axis=0) for inp in inputs] + elif hasattr(x, "__next__") and hasattr(x, "__iter__"): + for inputs in x: + if isinstance(inputs, np.ndarray): + yield [np.expand_dims(inputs, axis=0)] + else: + yield [np.expand_dims(inp, axis=0) for inp in inputs] + else: + raise ValueError( + "Expected either a list of inputs or a Numpy array with implicit initial " + f"batch dimension or an iterator yielding one of the above. Received: {x}" + ) + class Interpreter: """Interpreter interface for Larq Compute Engine Models. @@ -51,9 +73,7 @@ def output_shapes(self) -> List[Tuple[int]]: """Returns a list of output shapes.""" return self.interpreter.output_shapes - def predict( - self, x: Union[np.ndarray, List[np.ndarray]], verbose: int = 0 - ) -> Union[np.ndarray, List[np.ndarray]]: + def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data: """Generates output predictions for the input samples. # Arguments @@ -65,21 +85,12 @@ def predict( Numpy array(s) of output predictions. """ - if not isinstance(x, (list, np.ndarray)) or len(x) == 0: - raise ValueError( - "Expected either a non-empty list of inputs or a Numpy array with " - f"implicit initial batch dimension. Received: {x}" - ) + data_iterator = data_generator(x) + if verbose >= 1: + data_iterator = tqdm(data_iterator) - if len(self.input_shapes) == 1: - x = [x] - - batch_iter = tqdm(zip(*x)) if verbose >= 1 else zip(*x) - prediction_batch_iter = ( - self.interpreter.predict([np.expand_dims(inp, axis=0) for inp in inputs]) - for inputs in batch_iter - ) - outputs = [np.concatenate(batches) for batches in zip(*prediction_batch_iter)] + prediction_iter = (self.interpreter.predict(inputs) for inputs in data_iterator) + outputs = [np.concatenate(batches) for batches in zip(*prediction_iter)] if len(self.output_shapes) == 1: return outputs[0] diff --git a/larq_compute_engine/tflite/tests/interpreter_test.py b/larq_compute_engine/tflite/tests/interpreter_test.py index 71e50a0d9..0196e8d28 100644 --- a/larq_compute_engine/tflite/tests/interpreter_test.py +++ b/larq_compute_engine/tflite/tests/interpreter_test.py @@ -7,7 +7,8 @@ from larq_compute_engine.tflite.python.interpreter import Interpreter -def test_interpreter(): +@pytest.mark.parametrize("use_iterator", [True, False]) +def test_interpreter(use_iterator): input_shape = (24, 24, 3) x = tf.keras.Input(input_shape) model = tf.keras.Model(x, tf.keras.layers.Flatten()(x)) @@ -22,11 +23,17 @@ def test_interpreter(): assert interpreter.input_shapes == [(1, *input_shape)] assert interpreter.output_shapes == [(1, np.product(input_shape))] - outputs = interpreter.predict(inputs, 1) + def input_fn(): + if use_iterator: + return (input for input in inputs) + return inputs + + outputs = interpreter.predict(input_fn(), 1) np.testing.assert_allclose(outputs, expected_outputs) -def test_interpreter_multi_input(): +@pytest.mark.parametrize("use_iterator", [True, False]) +def test_interpreter_multi_input(use_iterator): x = tf.keras.Input((24, 24, 2)) y = tf.keras.Input((24, 24, 1)) model = tf.keras.Model( @@ -45,7 +52,12 @@ def test_interpreter_multi_input(): assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)] assert interpreter.output_shapes == [(1, 24 * 24 * 2), (1, 24 * 24 * 1)] - output_x, output_y = interpreter.predict([x_np, y_np]) + def input_fn(): + if use_iterator: + return ([x, y] for x, y in zip(x_np, y_np)) + return [x_np, y_np] + + output_x, output_y = interpreter.predict(input_fn()) np.testing.assert_allclose(output_x, expected_output_x) np.testing.assert_allclose(output_y, expected_output_y)