Add Python wrapper for LCE interpreter (#507)
Co-authored-by: Tom Bannink <tombannink@gmail.com>
Co-authored-by: Adam Hillier <7688302+AdamHillier@users.noreply.github.com>
3 people committed Sep 18, 2020
1 parent 6bbc061 commit d72b679
Showing 18 changed files with 475 additions and 130 deletions.
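
The commit adds a Python `Interpreter` wrapper around the LCE TFLite interpreter (replacing the test-only `end2end_verify` pybind extension). As a rough usage sketch of the new API — the toy Keras model and random input data below are placeholders, not part of the commit:

```python
import numpy as np
import tensorflow as tf
import larq_compute_engine as lce

# Placeholder Keras model; any tf.keras.Model converted by LCE works the same way.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Convert to an LCE flatbuffer and wrap it in the new interpreter.
# The wrapper is re-exported as `lce.testing` (see the __init__.py change below).
flatbuffer_model = lce.convert_keras_model(model)
interpreter = lce.testing.Interpreter(flatbuffer_model)

# `predict` iterates over the first (batch) axis; verbose=1 shows a tqdm progress bar.
input_data = np.random.uniform(0, 1, size=(16, 28, 28)).astype(np.float32)
predictions = interpreter.predict(input_data, verbose=1)  # shape (16, 10)
```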
4 changes: 3 additions & 1 deletion .github/workflows/unittests.yml
@@ -73,7 +73,9 @@ jobs:
run: ./configure.sh <<< $'n\n'
shell: bash
- name: Install pip dependencies
run: pip install tensorflow-cpu~=2.2.0 larq~=0.10.0 larq_zoo~=1.0.0 pytest tensorflow_datasets~=3.2.0 flatbuffers --no-cache-dir
run: pip install tensorflow-cpu~=2.2.0 larq~=0.10.0 larq_zoo~=1.0.0 pytest tensorflow_datasets~=3.2.0 flatbuffers tqdm --no-cache-dir
- name: Run Interpreter test
run: bazelisk test larq_compute_engine/tflite/tests:interpreter_test --test_output=all --copt=-O2 --distinct_host_configuration=false $([ -z "$GOOGLE_APPLICATION_CREDENTIALS" ] || echo "--remote_http_cache=https://storage.googleapis.com/plumerai-bazel-cache/lce-ubuntu --google_default_credentials")
- name: Run FileCheck tests
run: bazelisk test larq_compute_engine/mlir/tests:all --test_output=all --copt=-O2 --distinct_host_configuration=false $([ -z "$GOOGLE_APPLICATION_CREDENTIALS" ] || echo "--remote_http_cache=https://storage.googleapis.com/plumerai-bazel-cache/lce-ubuntu --google_default_credentials")
- name: Run End2End tests
4 changes: 2 additions & 2 deletions build_pip_pkg.sh
@@ -77,8 +77,8 @@ function main() {

if ! is_windows; then
echo "=== Stripping symbols"
chmod +w ${TMPDIR}/larq_compute_engine/mlir/*.so
strip -x ${TMPDIR}/larq_compute_engine/mlir/*.so
chmod +w ${TMPDIR}/larq_compute_engine/mlir/*.so ${TMPDIR}/larq_compute_engine/tflite/python/*.so
strip -x ${TMPDIR}/larq_compute_engine/mlir/*.so ${TMPDIR}/larq_compute_engine/tflite/python/*.so
fi

echo $(date) : "=== Building wheel"
2 changes: 2 additions & 0 deletions larq_compute_engine/BUILD
@@ -17,9 +17,11 @@ py_library(
srcs = ([
"__init__.py",
"//larq_compute_engine/mlir:__init__.py",
"//larq_compute_engine/tflite:__init__.py",
"//larq_compute_engine/mlir:python/__init__.py",
]),
deps = [
"//larq_compute_engine/mlir:converter",
"//larq_compute_engine/tflite/python:interpreter",
],
)
3 changes: 2 additions & 1 deletion larq_compute_engine/__init__.py
@@ -1,3 +1,4 @@
from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.tflite.python import interpreter as testing

__all__ = ["convert_keras_model"]
__all__ = ["convert_keras_model", "testing"]
3 changes: 3 additions & 0 deletions larq_compute_engine/mlir/python/converter_test.py
@@ -6,6 +6,9 @@
from tensorflow.python.eager import context

sys.modules["larq_compute_engine.mlir._graphdef_tfl_flatbuffer"] = mock.MagicMock()
sys.modules[
"larq_compute_engine.tflite.python.interpreter_wrapper_lite"
] = mock.MagicMock()

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.mlir._graphdef_tfl_flatbuffer import (
18 changes: 1 addition & 17 deletions larq_compute_engine/tests/BUILD
@@ -1,31 +1,15 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_linkopts")

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)

pybind_extension(
name = "_end2end_verify",
srcs = ["end2end_verify.cc"],
linkopts = tflite_linkopts(),
module_name = "end2end_verify",
deps = [
"//larq_compute_engine/tflite/kernels:lce_op_kernels",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@pybind11",
],
)

py_test(
name = "end2end_test",
size = "large",
srcs = ["end2end_test.py"],
deps = [
":_end2end_verify",
"//larq_compute_engine/mlir:converter",
"//larq_compute_engine/tflite/python:interpreter",
],
)

20 changes: 8 additions & 12 deletions larq_compute_engine/tests/end2end_test.py
@@ -1,14 +1,15 @@
import math
import sys
import pytest

import larq as lq
import larq_zoo as lqz
import numpy as np
import pytest
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import math

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.tests._end2end_verify import run_model
from larq_compute_engine.tflite.python.interpreter import Interpreter


def toy_model(**kwargs):
@@ -154,14 +155,9 @@ def preprocess(data):


def assert_model_output(model_lce, inputs, outputs):
for input, output in zip(inputs, outputs):
actual_outputs = run_model(model_lce, list(input.flatten()))
assert len(actual_outputs) > 1
np.testing.assert_allclose(actual_outputs[0], actual_outputs[1], rtol=1e-5)
for actual_output in actual_outputs:
np.testing.assert_allclose(
actual_output, output.flatten(), rtol=0.001, atol=0.25
)
interpreter = Interpreter(model_lce)
actual_outputs = interpreter.predict(inputs)
np.testing.assert_allclose(actual_outputs, outputs, rtol=0.001, atol=0.25)


@pytest.mark.parametrize(
96 changes: 0 additions & 96 deletions larq_compute_engine/tests/end2end_verify.cc

This file was deleted.

4 changes: 4 additions & 0 deletions larq_compute_engine/tflite/BUILD
@@ -2,3 +2,7 @@ package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)

exports_files([
"__init__.py",
])
Empty file.
41 changes: 41 additions & 0 deletions larq_compute_engine/tflite/python/BUILD
@@ -0,0 +1,41 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_linkopts")

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)

cc_library(
name = "interpreter_wrapper_utils",
hdrs = ["interpreter_wrapper_utils.h"],
deps = [
"@org_tensorflow//tensorflow/lite/c:common",
"@pybind11",
],
)

pybind_extension(
name = "interpreter_wrapper_lite",
srcs = ["interpreter_wrapper_lite.cc"],
linkopts = tflite_linkopts(),
module_name = "interpreter_wrapper_lite",
deps = [
":interpreter_wrapper_utils",
"//larq_compute_engine/tflite/kernels:lce_op_kernels",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@pybind11",
],
)

py_library(
name = "interpreter",
srcs = [
"__init__.py",
"interpreter.py",
],
deps = [
":interpreter_wrapper_lite",
],
)
Empty file.
86 changes: 86 additions & 0 deletions larq_compute_engine/tflite/python/interpreter.py
@@ -0,0 +1,86 @@
from typing import List, Tuple, Union

import numpy as np
from tqdm import tqdm

from larq_compute_engine.tflite.python import interpreter_wrapper_lite

__all__ = ["Interpreter"]


class Interpreter:
"""Interpreter interface for Larq Compute Engine Models.
!!! example
```python
lce_model = convert_keras_model(model)
interpreter = Interpreter(lce_model)
interpreter.predict(input_data, verbose=1)
```
# Arguments
flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
# Attributes
input_types: Returns a list of input types.
input_shapes: Returns a list of input shapes.
output_types: Returns a list of output types.
output_shapes: Returns a list of output shapes.
"""

def __init__(self, flatbuffer_model: bytes):
self.interpreter = interpreter_wrapper_lite.LiteInterpreter(flatbuffer_model)

@property
def input_types(self) -> list:
"""Returns a list of input types."""
return self.interpreter.input_types

@property
def input_shapes(self) -> List[Tuple[int]]:
"""Returns a list of input shapes."""
return self.interpreter.input_shapes

@property
def output_types(self) -> list:
"""Returns a list of output types."""
return self.interpreter.output_types

@property
def output_shapes(self) -> List[Tuple[int]]:
"""Returns a list of output shapes."""
return self.interpreter.output_shapes

def predict(
self, x: Union[np.ndarray, List[np.ndarray]], verbose: int = 0
) -> Union[np.ndarray, List[np.ndarray]]:
"""Generates output predictions for the input samples.
# Arguments
x: Input samples. A Numpy array, or a list of arrays in case the model has
multiple inputs.
verbose: Verbosity mode, 0 or 1.
# Returns
Numpy array(s) of output predictions.
"""

if not isinstance(x, (list, np.ndarray)) or len(x) == 0:
raise ValueError(
"Expected either a non-empty list of inputs or a Numpy array with "
f"implicit initial batch dimension. Received: {x}"
)

if len(self.input_shapes) == 1:
x = [x]

batch_iter = tqdm(zip(*x)) if verbose >= 1 else zip(*x)
prediction_batch_iter = (
self.interpreter.predict([np.expand_dims(inp, axis=0) for inp in inputs])
for inputs in batch_iter
)
outputs = [np.concatenate(batches) for batches in zip(*prediction_batch_iter)]

if len(self.output_shapes) == 1:
return outputs[0]
return outputs
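
For multi-input models, `predict` accepts a list with one array per model input, zips them along the batch axis, runs the underlying interpreter one sample at a time, and concatenates the per-sample results; single-output models get a single array back. A hedged sketch of that behaviour (the two-input Keras model here is purely illustrative, not part of this commit):

```python
import numpy as np
import tensorflow as tf
from larq_compute_engine import convert_keras_model
from larq_compute_engine.tflite.python.interpreter import Interpreter

# Illustrative two-input, single-output model (assumption, not from the commit).
in_a = tf.keras.Input(shape=(8,))
in_b = tf.keras.Input(shape=(8,))
out = tf.keras.layers.Dense(4)(tf.keras.layers.concatenate([in_a, in_b]))
model = tf.keras.Model([in_a, in_b], out)

interpreter = Interpreter(convert_keras_model(model))

# One array per model input; the first axis of each array is the batch dimension.
x = [np.random.rand(32, 8).astype(np.float32),
     np.random.rand(32, 8).astype(np.float32)]
y = interpreter.predict(x)  # single output -> one array of shape (32, 4)
```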
