-
Notifications
You must be signed in to change notification settings - Fork 33
/
interpreter.py
100 lines (78 loc) · 3.24 KB
/
interpreter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Iterator, List, Tuple, Union
import numpy as np
from tqdm import tqdm
from larq_compute_engine.tflite.python import interpreter_wrapper_lite
# Restrict `from <this module> import *` to the Interpreter class.
__all__ = [
    "Interpreter",
]
# Model input data: a single Numpy array (single-input model) or a list of
# Numpy arrays (one per model input). In both cases the first axis indexes
# the samples.
Data = Union[np.ndarray, List[np.ndarray]]


def data_generator(x: Union[Data, Iterator[Data]]) -> Iterator[List[np.ndarray]]:
    """Yield model inputs one sample at a time, each with a batch axis of size 1.

    # Arguments
    x: One of
        - a Numpy array whose first axis indexes the samples (single input),
        - a list of Numpy arrays, one per model input, whose first axes index
          the samples and must all have the same length (multiple inputs),
        - an iterator yielding, per sample, either an array or a sequence of
          arrays (one per model input) *without* the leading sample axis.

    # Returns
    An iterator of lists of Numpy arrays: one array per model input, each
    with a new leading axis of size 1 (added via `np.expand_dims`).

    # Raises
    ValueError: If `x` is none of the supported forms, or if the arrays of a
        list input have different numbers of samples. As this is a generator,
        the error is raised when iteration starts, not when it is called.
    """
    if isinstance(x, np.ndarray):
        for sample in x:
            yield [np.expand_dims(sample, axis=0)]
    elif isinstance(x, list):
        # `zip` stops at the shortest input, so mismatched sample counts
        # would otherwise silently drop samples from the longer inputs.
        lengths = [len(inp) for inp in x]
        if len(set(lengths)) > 1:
            raise ValueError(
                "All model inputs must have the same number of samples. "
                f"Received inputs with sample counts: {lengths}"
            )
        for samples in zip(*x):
            yield [np.expand_dims(sample, axis=0) for sample in samples]
    elif hasattr(x, "__next__") and hasattr(x, "__iter__"):
        for inputs in x:
            # Each element is one sample: a bare array for a single-input
            # model, or an iterable of arrays for a multi-input model.
            if isinstance(inputs, np.ndarray):
                yield [np.expand_dims(inputs, axis=0)]
            else:
                yield [np.expand_dims(inp, axis=0) for inp in inputs]
    else:
        raise ValueError(
            "Expected either a list of inputs or a Numpy array with implicit initial "
            f"batch dimension or an iterator yielding one of the above. Received: {x}"
        )
class Interpreter:
    """Interpreter interface for Larq Compute Engine Models.

    !!! example
        ```python
        lce_model = convert_keras_model(model)
        interpreter = Interpreter(lce_model)
        interpreter.predict(input_data, verbose=1)
        ```

    # Arguments
    flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
    num_threads: The number of threads used by the interpreter.

    # Attributes
    input_types: Returns a list of input types.
    input_shapes: Returns a list of input shapes.
    output_types: Returns a list of output types.
    output_shapes: Returns a list of output shapes.
    """

    def __init__(self, flatbuffer_model: bytes, num_threads: int = 1):
        # All model metadata and inference calls are delegated to this
        # native wrapper object.
        self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
            flatbuffer_model, num_threads
        )

    @property
    def input_types(self) -> list:
        """Returns a list of input types."""
        return self.interpreter.input_types

    @property
    def input_shapes(self) -> List[Tuple[int, ...]]:
        """Returns a list of input shapes."""
        return self.interpreter.input_shapes

    @property
    def output_types(self) -> list:
        """Returns a list of output types."""
        return self.interpreter.output_types

    @property
    def output_shapes(self) -> List[Tuple[int, ...]]:
        """Returns a list of output shapes."""
        return self.interpreter.output_shapes

    @staticmethod
    def _num_samples(x: Union[Data, Iterator[Data]]) -> Union[int, None]:
        """Return the number of samples in `x` if known up front, else None.

        Only used to give the progress bar a total; iterators have no length.
        """
        try:
            if isinstance(x, np.ndarray):
                return len(x)
            if isinstance(x, list) and x:
                return len(x[0])
        except TypeError:
            # Unsized input (e.g. a 0-d array): the progress bar simply
            # gets no total.
            pass
        return None

    def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data:
        """Generates output predictions for the input samples.

        # Arguments
        x: Input samples. A Numpy array, or a list of arrays in case the model has
            multiple inputs. Alternatively, an iterator yielding one sample at a
            time (an array, or a list of arrays for multiple inputs) without the
            leading sample dimension.
        verbose: Verbosity mode, 0 or 1. Values >= 1 display a progress bar.

        # Returns
        Numpy array(s) of output predictions: a single array if the model has one
        output, otherwise a list with one array per output.

        # Raises
        ValueError: If `x` has an unsupported form or contains no samples.
        """
        data_iterator = data_generator(x)
        if verbose >= 1:
            data_iterator = tqdm(data_iterator, total=self._num_samples(x))
        # Each wrapper call is fed a single sample; `zip(*...)` regroups the
        # per-sample results by output position so that each output's samples
        # can be concatenated along the first axis.
        prediction_iter = (self.interpreter.predict(inputs) for inputs in data_iterator)
        outputs = [np.concatenate(batches) for batches in zip(*prediction_iter)]
        if not outputs:
            # No samples means no predictions; fail clearly here rather than
            # with an IndexError on `outputs[0]` below.
            raise ValueError("Expected at least one input sample, but received none.")
        if len(self.output_shapes) == 1:
            return outputs[0]
        return outputs