Skip to content

Commit

Permalink
Reset model object before using it to unpack a flatbuffer to avoid th…
Browse files Browse the repository at this point in the history
…e memory leak

The existing code relies on UnPackTo to merge the existing state with the new incoming state. However, this behavior was changed in google/flatbuffers#7527 and now leads to memory leak. Ignore the commit description there as it is outdated. During the review, they decided to not have the old behavior as an option.

PiperOrigin-RevId: 484798857
  • Loading branch information
smit-hinsu authored and tensorflower-gardener committed Oct 30, 2022
1 parent 269e683 commit c077b75
Showing 1 changed file with 35 additions and 23 deletions.
Expand Up @@ -43,6 +43,18 @@ namespace tflite {
namespace optimize {
namespace {

// Unpacks the given flatbuffer model.
//
// This helper is useful as UnPackTo requires the input to not have any existing
// state so directly calling UnPackTo could lead to memory leaks if the model
// already had some state. Instead, the returned object from here can be used to
// overwrite existing model.
ModelT UnPackFlatBufferModel(const Model& flatbuffer_model) {
ModelT model;
flatbuffer_model.UnPackTo(&model);
return model;
}

TfLiteStatus QuantizeModel(
flatbuffers::FlatBufferBuilder* builder, ModelT* model,
const TensorType& input_type, const TensorType& output_type,
Expand All @@ -67,7 +79,7 @@ TfLiteStatus QuantizeModel(

auto flatbuffer_model =
FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size());
flatbuffer_model->GetModel()->UnPackTo(model);
*model = UnPackFlatBufferModel(*flatbuffer_model->GetModel());
return kTfLiteOk;
}

Expand Down Expand Up @@ -144,7 +156,7 @@ class QuantizeModelTest : public testing::Test {
QuantizeModelTest() {
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}

std::unique_ptr<FlatBufferModel> input_model_;
Expand Down Expand Up @@ -230,7 +242,7 @@ class QuantizeConvModelTest : public QuantizeModelTest,
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
// Flatbuffer is missing calibration data -- add dummy params.
auto& subgraph = model_.subgraphs[0];
auto* input = subgraph->tensors[subgraph->inputs[0]].get();
Expand Down Expand Up @@ -304,7 +316,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
QuantizeConvNoBiasModelTest() {
input_model_ = ReadModel(internal::kConvModelWithNoBias);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand All @@ -323,7 +335,7 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
QuantizeSplitModelTest() {
input_model_ = ReadModel(internal::kModelSplit);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -401,7 +413,7 @@ class QuantizeConvModel2Test : public QuantizeModelTest,
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
auto& subgraph = model_.subgraphs[0];
auto* input = subgraph->tensors[subgraph->inputs[0]].get();
auto* output = subgraph->tensors[subgraph->outputs[0]].get();
Expand Down Expand Up @@ -636,7 +648,7 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
QuantizeSoftmaxTest() {
input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -698,7 +710,7 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
QuantizeAvgPoolTest() {
input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -757,7 +769,7 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
QuantizeMultiInputAddWithReshapeTest() {
input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -877,7 +889,7 @@ class QuantizeConstInputTest : public QuantizeModelTest,
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kConstInputAddModel);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}

TensorType tensor_type_;
Expand Down Expand Up @@ -926,7 +938,7 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
QuantizeArgMaxTest() {
input_model_ = ReadModel(internal::kModelWithArgMaxOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -971,7 +983,7 @@ class QuantizeLSTMTest : public QuantizeModelTest {
QuantizeLSTMTest() {
input_model_ = ReadModel(internal::kLstmCalibrated);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand All @@ -996,7 +1008,7 @@ class QuantizeLSTM2Test : public QuantizeModelTest {
QuantizeLSTM2Test() {
input_model_ = ReadModel(internal::kLstmCalibrated2);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand All @@ -1021,7 +1033,7 @@ class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
QuantizeUnidirectionalSequenceLSTMTest() {
input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand All @@ -1048,7 +1060,7 @@ class QuantizeSVDFTest : public QuantizeModelTest {
QuantizeSVDFTest() {
input_model_ = ReadModel(internal::kSvdfCalibrated);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand All @@ -1073,7 +1085,7 @@ class QuantizeFCTest : public QuantizeModelTest {
QuantizeFCTest() {
input_model_ = ReadModel(internal::kModelWithFCOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -1125,7 +1137,7 @@ class QuantizeCustomOpTest
QuantizeCustomOpTest() {
input_model_ = ReadModel(internal::kModelMixed);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -1164,7 +1176,7 @@ class QuantizePackTest : public QuantizeModelTest {
QuantizePackTest() {
input_model_ = ReadModel(internal::kModelPack);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -1228,7 +1240,7 @@ class QuantizeMinimumMaximumTest
QuantizeMinimumMaximumTest() {
input_model_ = ReadModel(GetParam());
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -1288,7 +1300,7 @@ class QuantizeUnpackTest : public QuantizeModelTest {
QuantizeUnpackTest() {
input_model_ = ReadModel(internal::kModelWithUnpack);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down Expand Up @@ -1340,7 +1352,7 @@ class QuantizeBroadcastToModelTest
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
TensorType tensor_type_;
};
Expand Down Expand Up @@ -1406,7 +1418,7 @@ class QuantizeGatherNDModelTest
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kModelWithGatherNDOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}

TensorType tensor_type_;
Expand Down Expand Up @@ -1467,7 +1479,7 @@ class QuantizeWhereModelTest : public QuantizeModelTest {
QuantizeWhereModelTest() {
input_model_ = ReadModel(internal::kModelWithWhereOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
model_ = UnPackFlatBufferModel(*readonly_model_);
}
};

Expand Down

0 comments on commit c077b75

Please sign in to comment.