Skip to content

Commit

Permalink
Add support for multi threaded interpreter (#512)
Browse files · Browse the repository at this point in the history
* Add support for multi threaded interpreter

* Run end2end tests using multithreaded interpreter
  • Loading branch information
lgeiger committed Sep 21, 2020
1 parent ca5f8dd commit d07a1ee
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 8 deletions.
3 changes: 2 additions & 1 deletion larq_compute_engine/tests/end2end_test.py
@@ -1,4 +1,5 @@
import math
import os
import sys

import larq as lq
Expand Down Expand Up @@ -155,7 +156,7 @@ def preprocess(data):


def assert_model_output(model_lce, inputs, outputs):
    """Run `inputs` through an LCE interpreter and compare predictions.

    # Arguments
    model_lce: A serialized Larq Compute Engine model in the flatbuffer format.
    inputs: Input data fed to the interpreter.
    outputs: Expected predictions, compared with rtol=0.001 / atol=0.25.
    """
    # Use at most 4 threads; `os.cpu_count()` may return None (undeterminable),
    # so fall back to 1 in that case instead of raising a TypeError in min().
    interpreter = Interpreter(model_lce, num_threads=min(os.cpu_count() or 1, 4))
    actual_outputs = interpreter.predict(inputs)
    np.testing.assert_allclose(actual_outputs, outputs, rtol=0.001, atol=0.25)

Expand Down
7 changes: 5 additions & 2 deletions larq_compute_engine/tflite/python/interpreter.py
Expand Up @@ -42,6 +42,7 @@ class Interpreter:
# Arguments
flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
num_threads: The number of threads used by the interpreter.
# Attributes
input_types: Returns a list of input types.
Expand All @@ -50,8 +51,10 @@ class Interpreter:
output_shapes: Returns a list of output shapes.
"""

def __init__(self, flatbuffer_model: bytes, num_threads: int = 1):
    # Build the native interpreter. `num_threads` defaults to 1 so existing
    # single-argument callers keep their previous single-threaded behavior.
    self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
        flatbuffer_model, num_threads
    )

@property
def input_types(self) -> list:
Expand Down
9 changes: 5 additions & 4 deletions larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
Expand Up @@ -8,7 +8,8 @@
class LiteInterpreterWrapper
: public InterpreterWrapperBase<tflite::Interpreter> {
public:
LiteInterpreterWrapper(const pybind11::bytes& flatbuffer);
LiteInterpreterWrapper(const pybind11::bytes& flatbuffer,
const int num_threads);
~LiteInterpreterWrapper(){};

private:
Expand All @@ -20,7 +21,7 @@ class LiteInterpreterWrapper
};

LiteInterpreterWrapper::LiteInterpreterWrapper(
const pybind11::bytes& flatbuffer) {
const pybind11::bytes& flatbuffer, const int num_threads = 1) {
// Make a copy of the flatbuffer because it can get deallocated after the
// constructor is done
flatbuffer_ = static_cast<std::string>(flatbuffer);
Expand All @@ -36,7 +37,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
compute_engine::tflite::RegisterLCECustomOps(resolver_.get());

tflite::InterpreterBuilder builder(*model_, *resolver_);
builder(&interpreter_);
builder(&interpreter_, num_threads);
MINIMAL_CHECK(interpreter_ != nullptr);

// Allocate tensor buffers.
Expand All @@ -45,7 +46,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(

PYBIND11_MODULE(interpreter_wrapper_lite, m) {
pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter")
.def(pybind11::init<const pybind11::bytes&>())
.def(pybind11::init<const pybind11::bytes&, const int>())
.def_property("input_types", &LiteInterpreterWrapper::get_input_types,
nullptr)
.def_property("output_types", &LiteInterpreterWrapper::get_output_types,
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/tflite/tests/interpreter_test.py
Expand Up @@ -46,7 +46,7 @@ def test_interpreter_multi_input(use_iterator):
expected_output_x = x_np.reshape(16, -1)
expected_output_y = y_np.reshape(16, -1)

interpreter = Interpreter(converter.convert())
interpreter = Interpreter(converter.convert(), num_threads=2)
assert interpreter.input_types == [np.float32, np.float32]
assert interpreter.output_types == [np.float32, np.float32]
assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)]
Expand Down

0 comments on commit d07a1ee

Please sign in to comment.