From d07a1eed8e8c1a9ea3960e9a753e9d3fa5262b89 Mon Sep 17 00:00:00 2001
From: Lukas Geiger
Date: Mon, 21 Sep 2020 14:13:10 +0200
Subject: [PATCH] Add support for multi-threaded interpreter (#512)

* Add support for multi-threaded interpreter

* Run end2end tests using multi-threaded interpreter
---
 larq_compute_engine/tests/end2end_test.py            | 3 ++-
 larq_compute_engine/tflite/python/interpreter.py     | 7 +++++--
 .../tflite/python/interpreter_wrapper_lite.cc        | 9 +++++----
 larq_compute_engine/tflite/tests/interpreter_test.py | 2 +-
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/larq_compute_engine/tests/end2end_test.py b/larq_compute_engine/tests/end2end_test.py
index 17a7a5966..b5817dc2d 100644
--- a/larq_compute_engine/tests/end2end_test.py
+++ b/larq_compute_engine/tests/end2end_test.py
@@ -1,4 +1,5 @@
 import math
+import os
 import sys
 
 import larq as lq
@@ -155,7 +156,7 @@ def preprocess(data):
 
 
 def assert_model_output(model_lce, inputs, outputs):
-    interpreter = Interpreter(model_lce)
+    interpreter = Interpreter(model_lce, num_threads=min(os.cpu_count(), 4))
     actual_outputs = interpreter.predict(inputs)
     np.testing.assert_allclose(actual_outputs, outputs, rtol=0.001, atol=0.25)
 
diff --git a/larq_compute_engine/tflite/python/interpreter.py b/larq_compute_engine/tflite/python/interpreter.py
index 02754188b..338ed2226 100644
--- a/larq_compute_engine/tflite/python/interpreter.py
+++ b/larq_compute_engine/tflite/python/interpreter.py
@@ -42,6 +42,7 @@ class Interpreter:
     # Arguments
         flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer
             format.
+        num_threads: The number of threads used by the interpreter.
 
     # Attributes
         input_types: Returns a list of input types.
@@ -50,8 +51,10 @@ class Interpreter:
         output_shapes: Returns a list of output shapes.
     """
 
-    def __init__(self, flatbuffer_model: bytes):
-        self.interpreter = interpreter_wrapper_lite.LiteInterpreter(flatbuffer_model)
+    def __init__(self, flatbuffer_model: bytes, num_threads: int = 1):
+        self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
+            flatbuffer_model, num_threads
+        )
 
     @property
     def input_types(self) -> list:
diff --git a/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc b/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
index 3f367a4de..7a382273a 100644
--- a/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
+++ b/larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
@@ -8,7 +8,8 @@
 class LiteInterpreterWrapper
     : public InterpreterWrapperBase<tflite::Interpreter> {
  public:
-  LiteInterpreterWrapper(const pybind11::bytes& flatbuffer);
+  LiteInterpreterWrapper(const pybind11::bytes& flatbuffer,
+                         const int num_threads);
   ~LiteInterpreterWrapper(){};
 
  private:
@@ -20,7 +21,7 @@ class LiteInterpreterWrapper
 };
 
 LiteInterpreterWrapper::LiteInterpreterWrapper(
-    const pybind11::bytes& flatbuffer) {
+    const pybind11::bytes& flatbuffer, const int num_threads = 1) {
   // Make a copy of the flatbuffer because it can get deallocated after the
   // constructor is done
   flatbuffer_ = static_cast<std::string>(flatbuffer);
@@ -36,7 +37,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
   compute_engine::tflite::RegisterLCECustomOps(resolver_.get());
 
   tflite::InterpreterBuilder builder(*model_, *resolver_);
-  builder(&interpreter_);
+  builder(&interpreter_, num_threads);
   MINIMAL_CHECK(interpreter_ != nullptr);
 
   // Allocate tensor buffers.
@@ -45,7 +46,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
 
 PYBIND11_MODULE(interpreter_wrapper_lite, m) {
   pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter")
-      .def(pybind11::init<const pybind11::bytes&>())
+      .def(pybind11::init<const pybind11::bytes&, const int>())
       .def_property("input_types", &LiteInterpreterWrapper::get_input_types,
                     nullptr)
       .def_property("output_types", &LiteInterpreterWrapper::get_output_types,
diff --git a/larq_compute_engine/tflite/tests/interpreter_test.py b/larq_compute_engine/tflite/tests/interpreter_test.py
index 0196e8d28..5c6b4e354 100644
--- a/larq_compute_engine/tflite/tests/interpreter_test.py
+++ b/larq_compute_engine/tflite/tests/interpreter_test.py
@@ -46,7 +46,7 @@ def test_interpreter_multi_input(use_iterator):
     expected_output_x = x_np.reshape(16, -1)
     expected_output_y = y_np.reshape(16, -1)
 
-    interpreter = Interpreter(converter.convert())
+    interpreter = Interpreter(converter.convert(), num_threads=2)
     assert interpreter.input_types == [np.float32, np.float32]
     assert interpreter.output_types == [np.float32, np.float32]
    assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)]
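
Note (not part of the patch): a minimal usage sketch of the new num_threads argument. It assumes the public converter API larq_compute_engine.convert_keras_model and that Interpreter is importable from larq_compute_engine.tflite.python.interpreter (the module modified above); the Keras model and input data are placeholders.

    import os

    import larq_compute_engine as lce
    import numpy as np
    import tensorflow as tf

    from larq_compute_engine.tflite.python.interpreter import Interpreter

    # Any Keras model can be converted; a tiny Dense model keeps the sketch short.
    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    flatbuffer_model = lce.convert_keras_model(model)

    # num_threads defaults to 1, so existing single-threaded callers keep their
    # previous behaviour; cap the thread count the same way the end2end test does.
    interpreter = Interpreter(flatbuffer_model, num_threads=min(os.cpu_count(), 4))
    outputs = interpreter.predict(np.random.uniform(-1, 1, (1, 8)).astype(np.float32))

As in the test changes above, callers opt in to multi-threading explicitly, so the default remains single-threaded.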