When no weight cache is provided to XNNPack, create one to share packed weights between operations.

PiperOrigin-RevId: 623781016
qukhan authored and tensorflower-gardener committed May 10, 2024
1 parent 132c4ee commit f05f44e
Showing 14 changed files with 2,108 additions and 15 deletions.
2 changes: 2 additions & 0 deletions tensorflow/lite/core/c/common.h
@@ -472,6 +472,8 @@ typedef enum TfLiteCustomAllocationFlags {
kTfLiteCustomAllocationFlagsSkipAlignCheck = 1,
} TfLiteCustomAllocationFlags;

enum { kTfLiteNoBufferIdentifier = SIZE_MAX };

/// A tensor in the interpreter system which is a wrapper around a buffer of
/// data including a dimensionality (or NULL if not currently defined).
#ifndef TF_LITE_STATIC_MEMORY
3 changes: 2 additions & 1 deletion tensorflow/lite/core/interpreter_builder.cc
@@ -691,7 +691,8 @@ TfLiteStatus InterpreterBuilder::ParseTensors(

if (subgraph->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_, sparsity) != kTfLiteOk) {
buffer_size, allocation_, sparsity,
/*buffer_identifier=*/tensor->buffer()) != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Tensor %d is invalidly specified in schema.\n",
i);
6 changes: 5 additions & 1 deletion tensorflow/lite/core/subgraph.cc
@@ -1856,7 +1856,8 @@ TfLiteStatus Subgraph::GetNodeAndRegistration(
TfLiteStatus Subgraph::SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name, const size_t ndims,
const int* dims, TfLiteQuantization quantization, const char* buffer,
size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity) {
size_t bytes, const Allocation* allocation, TfLiteSparsity* sparsity,
const size_t buffer_identifier) {
// Ensure quantization cleanup on failure.
ScopedTfLiteQuantization scoped_quantization(&quantization);
ScopedTfLiteSparsity scoped_sparsity(sparsity);
@@ -1904,6 +1905,9 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
tensor.quantization = *scoped_quantization.release();
tensor.sparsity = scoped_sparsity.release();
}
if (buffer_identifier != kTfLiteNoBufferIdentifier) {
tensor_buffer_identifiers_[tensor_index] = buffer_identifier;
}
return kTfLiteOk;
}

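For readers following the data flow: the interpreter_builder.cc and subgraph.cc hunks above record, for every read-only tensor that is backed by a flatbuffer buffer, that buffer's index as a model-wide identifier, skipping tensors registered with the kTfLiteNoBufferIdentifier sentinel. The following is a minimal standalone sketch, not the TFLite implementation; TensorRegistry, RegisterReadOnly, and ShareSameBuffer are illustrative names only. It shows why such a map lets a consumer detect that two tensors share the same constant data.

// Minimal sketch, not the TFLite implementation. All names below are
// illustrative; only the sentinel value mirrors the real code.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>

constexpr size_t kNoBufferIdentifier = SIZE_MAX;  // mirrors kTfLiteNoBufferIdentifier

struct TensorRegistry {
  // Tensor index -> flatbuffer buffer index, filled only for tensors that
  // actually reference a model buffer (cf. Subgraph::SetTensorParametersReadOnly).
  std::unordered_map<size_t, size_t> tensor_buffer_identifiers;

  void RegisterReadOnly(size_t tensor_index, size_t buffer_identifier) {
    if (buffer_identifier != kNoBufferIdentifier) {
      tensor_buffer_identifiers[tensor_index] = buffer_identifier;
    }
  }

  bool ShareSameBuffer(size_t a, size_t b) const {
    auto ia = tensor_buffer_identifiers.find(a);
    auto ib = tensor_buffer_identifiers.find(b);
    return ia != tensor_buffer_identifiers.end() &&
           ib != tensor_buffer_identifiers.end() && ia->second == ib->second;
  }
};

int main() {
  TensorRegistry registry;
  registry.RegisterReadOnly(/*tensor_index=*/3, /*buffer_identifier=*/7);
  registry.RegisterReadOnly(/*tensor_index=*/9, /*buffer_identifier=*/7);
  registry.RegisterReadOnly(/*tensor_index=*/4, kNoBufferIdentifier);  // not recorded
  std::cout << std::boolalpha << registry.ShareSameBuffer(3, 9) << "\n";  // true
  std::cout << std::boolalpha << registry.ShareSameBuffer(3, 4) << "\n";  // false
}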
17 changes: 14 additions & 3 deletions tensorflow/lite/core/subgraph.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -132,16 +133,18 @@ class Subgraph {
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantization quantization,
const char* buffer, size_t bytes, const Allocation* allocation = nullptr,
TfLiteSparsity* sparsity = nullptr) {
TfLiteSparsity* sparsity = nullptr,
size_t buffer_identifier = kTfLiteNoBufferIdentifier) {
return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
dims.data(), quantization, buffer, bytes,
allocation, sparsity);
allocation, sparsity, buffer_identifier);
}
TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name, const size_t ndims,
const int* dims, TfLiteQuantization quantization, const char* buffer,
size_t bytes, const Allocation* allocation = nullptr,
TfLiteSparsity* sparsity = nullptr);
TfLiteSparsity* sparsity = nullptr,
size_t buffer_identifier = kTfLiteNoBufferIdentifier);

// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
@@ -589,6 +592,10 @@ class Subgraph {
// Returns true if the subgraph has been fully delegated.
bool IsFullyDelegated() const;

const std::unordered_map<size_t, size_t>& GetTensorBufferIdentifiers() {
return tensor_buffer_identifiers_;
}

private:
#ifndef DOXYGEN_SKIP
friend class tflite::impl::InterpreterBuilder;
@@ -1153,6 +1160,10 @@ class Subgraph {
/// The allocator used for holding memory of the model. Note that this will
/// be null if the client provides a tflite::Model directly.
const Allocation* allocation_ = nullptr;

// Maps tensor constant buffers used in the subgraph to model-wide
// identifiers.
std::unordered_map<size_t, size_t> tensor_buffer_identifiers_;
};

} // namespace tflite
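The new GetTensorBufferIdentifiers() accessor is what allows a delegate to key derived data by model buffer rather than by tensor index or pointer. Below is a hedged sketch of a hypothetical consumer; PackedWeightCache, GetOrPack, and the packing step are invented for illustration and are not the delegate's actual weight cache (that implementation lives in the new weight_cache.cc added by this commit). The point is that operations reading the same constant buffer end up sharing one packed copy.

// Hypothetical consumer sketch; PackedWeightCache and GetOrPack are
// placeholders, not XNNPack or TFLite APIs.
#include <cstddef>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

class PackedWeightCache {
 public:
  // buffer_id: model-wide identifier as returned by GetTensorBufferIdentifiers().
  // variant: distinguishes different packing layouts of the same buffer.
  const std::vector<uint8_t>& GetOrPack(size_t buffer_id, size_t variant,
                                        const uint8_t* data, size_t size) {
    const std::pair<size_t, size_t> key{buffer_id, variant};
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;       // reuse packed weights
    std::vector<uint8_t> packed(data, data + size);  // stand-in for real packing
    return cache_.emplace(key, std::move(packed)).first->second;
  }

 private:
  std::map<std::pair<size_t, size_t>, std::vector<uint8_t>> cache_;
};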
48 changes: 41 additions & 7 deletions tensorflow/lite/delegates/xnnpack/BUILD
@@ -1,3 +1,4 @@
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist", "tflite_portable_test_suite_combined")
@@ -246,11 +247,7 @@ cc_library(
linkstatic = True,
deps = [
":quantization_util",
":tflite_with_xnnpack_dynamic_fully_connected",
":tflite_with_xnnpack_logging",
":tflite_with_xnnpack_qs8",
":tflite_with_xnnpack_qu8",
":tflite_with_xnnpack_transient_indirection_buffer",
":weight_cache",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:c_api_types",
@@ -267,7 +264,6 @@ cc_library(
"//tensorflow/lite/tools/optimize:reduced_precision_support",
"@XNNPACK",
"@XNNPACK//:experiments_config",
"@XNNPACK//:logging",
],
)

@@ -289,6 +285,7 @@ cc_library(
linkstatic = True,
deps = [
":quantization_util",
":weight_cache",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:c_api_types",
@@ -305,7 +302,6 @@ cc_library(
"//tensorflow/lite/tools/optimize:reduced_precision_support",
"@XNNPACK//:XNNPACK_test_mode",
"@XNNPACK//:experiments_config",
"@XNNPACK//:logging",
],
)

@@ -323,6 +319,30 @@ cc_library(
],
)

flatbuffer_cc_library(
name = "weight_cache_schema",
srcs = ["weight_cache_schema.fbs"],
compatible_with = get_compatible_with_portable(),
flatc_args = [
"--gen-mutable",
"--gen-object-api",
],
)

cc_library(
name = "weight_cache",
srcs = ["weight_cache.cc"],
hdrs = ["weight_cache.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":weight_cache_schema",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/c:common",
"@XNNPACK",
"@flatbuffers//:runtime_cc",
],
)

################################ Tester classes ################################

cc_library(
@@ -2828,4 +2848,18 @@ cc_test(
],
)

cc_test(
name = "weight_cache_test",
srcs = ["weight_cache_test.cc"],
deps = [
":test_main",
":weight_cache",
":weight_cache_schema",
"//tensorflow/lite/c:common",
"@XNNPACK",
"@com_google_googletest//:gtest",
"@flatbuffers//:runtime_cc",
],
)

tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})

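Assuming a standard Bazel checkout of TensorFlow, the new cache code can presumably be exercised with something like `bazel test //tensorflow/lite/delegates/xnnpack:weight_cache_test`, which builds the weight_cache library together with its generated weight_cache_schema flatbuffer code.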