diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
index 3dbd30e1dac7a3..db4674a2a57062 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
@@ -43,6 +43,18 @@ namespace tflite {
 namespace optimize {
 namespace {
 
+// Unpacks the given flatbuffer model.
+//
+// This helper is useful as UnPackTo requires the input to not have any
+// existing state, so directly calling UnPackTo could lead to memory leaks if
+// the model already has some state. Instead, the object returned here can be
+// used to overwrite an existing model.
+ModelT UnPackFlatBufferModel(const Model& flatbuffer_model) {
+  ModelT model;
+  flatbuffer_model.UnPackTo(&model);
+  return model;
+}
+
 TfLiteStatus QuantizeModel(
     flatbuffers::FlatBufferBuilder* builder, ModelT* model,
     const TensorType& input_type, const TensorType& output_type,
@@ -67,7 +79,7 @@ TfLiteStatus QuantizeModel(
 
   auto flatbuffer_model =
       FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size());
-  flatbuffer_model->GetModel()->UnPackTo(model);
+  *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel());
   return kTfLiteOk;
 }
 
@@ -144,7 +156,7 @@ class QuantizeModelTest : public testing::Test {
   QuantizeModelTest() {
     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 
   std::unique_ptr<FlatBufferModel> input_model_;
@@ -230,7 +242,7 @@ class QuantizeConvModelTest : public QuantizeModelTest,
     tensor_type_ = GetParam();
     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
     // Flatbuffer is missing calibration data -- add dummy params.
     auto& subgraph = model_.subgraphs[0];
     auto* input = subgraph->tensors[subgraph->inputs[0]].get();
@@ -304,7 +316,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
   QuantizeConvNoBiasModelTest() {
     input_model_ = ReadModel(internal::kConvModelWithNoBias);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -323,7 +335,7 @@ class QuantizeSplitModelTest : public QuantizeModelTest {
   QuantizeSplitModelTest() {
     input_model_ = ReadModel(internal::kModelSplit);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -401,7 +413,7 @@ class QuantizeConvModel2Test : public QuantizeModelTest,
     tensor_type_ = GetParam();
     input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
     auto& subgraph = model_.subgraphs[0];
     auto* input = subgraph->tensors[subgraph->inputs[0]].get();
     auto* output = subgraph->tensors[subgraph->outputs[0]].get();
@@ -636,7 +648,7 @@ class QuantizeSoftmaxTest : public QuantizeModelTest {
   QuantizeSoftmaxTest() {
     input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -698,7 +710,7 @@ class QuantizeAvgPoolTest : public QuantizeModelTest {
   QuantizeAvgPoolTest() {
     input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -757,7 +769,7 @@ class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
   QuantizeMultiInputAddWithReshapeTest() {
     input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -877,7 +889,7 @@ class QuantizeConstInputTest : public QuantizeModelTest,
     tensor_type_ = GetParam();
     input_model_ = ReadModel(internal::kConstInputAddModel);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 
   TensorType tensor_type_;
@@ -926,7 +938,7 @@ class QuantizeArgMaxTest : public QuantizeModelTest {
   QuantizeArgMaxTest() {
     input_model_ = ReadModel(internal::kModelWithArgMaxOp);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -971,7 +983,7 @@ class QuantizeLSTMTest : public QuantizeModelTest {
   QuantizeLSTMTest() {
     input_model_ = ReadModel(internal::kLstmCalibrated);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -996,7 +1008,7 @@ class QuantizeLSTM2Test : public QuantizeModelTest {
   QuantizeLSTM2Test() {
     input_model_ = ReadModel(internal::kLstmCalibrated2);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1021,7 +1033,7 @@ class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
   QuantizeUnidirectionalSequenceLSTMTest() {
     input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1048,7 +1060,7 @@ class QuantizeSVDFTest : public QuantizeModelTest {
   QuantizeSVDFTest() {
     input_model_ = ReadModel(internal::kSvdfCalibrated);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1073,7 +1085,7 @@ class QuantizeFCTest : public QuantizeModelTest {
   QuantizeFCTest() {
     input_model_ = ReadModel(internal::kModelWithFCOp);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1125,7 +1137,7 @@ class QuantizeCustomOpTest
   QuantizeCustomOpTest() {
     input_model_ = ReadModel(internal::kModelMixed);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1164,7 +1176,7 @@ class QuantizePackTest : public QuantizeModelTest {
   QuantizePackTest() {
     input_model_ = ReadModel(internal::kModelPack);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1228,7 +1240,7 @@ class QuantizeMinimumMaximumTest
   QuantizeMinimumMaximumTest() {
     input_model_ = ReadModel(GetParam());
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1288,7 +1300,7 @@ class QuantizeUnpackTest : public QuantizeModelTest {
   QuantizeUnpackTest() {
     input_model_ = ReadModel(internal::kModelWithUnpack);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 };
 
@@ -1340,7 +1352,7 @@ class QuantizeBroadcastToModelTest
     tensor_type_ = GetParam();
     input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
   TensorType tensor_type_;
 };
@@ -1406,7 +1418,7 @@ class QuantizeGatherNDModelTest
     tensor_type_ = GetParam();
     input_model_ = ReadModel(internal::kModelWithGatherNDOp);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
   }
 
   TensorType tensor_type_;
@@ -1467,7 +1479,7 @@ class QuantizeWhereModelTest : public QuantizeModelTest {
   QuantizeWhereModelTest() {
     input_model_ = ReadModel(internal::kModelWithWhereOp);
     readonly_model_ = input_model_->GetModel();
-    readonly_model_->UnPackTo(&model_);
+    model_ = UnPackFlatBufferModel(*readonly_model_);
  }
 };
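
For readers unfamiliar with the flatbuffers object API, the sketch below is a minimal, self-contained illustration of the pattern this patch adopts; it is not part of the change. The helper mirrors UnPackFlatBufferModel as added above, while ReloadQuantizedModel and the exact include paths are assumptions made only for the example: unpacking into a fresh ModelT and assigning the result replaces whatever the destination previously owned, instead of calling UnPackTo on an object that already holds state.

// Illustrative sketch only -- not part of the patch. Include paths and the
// ReloadQuantizedModel call site are assumptions made for this example.
#include <memory>
#include <string>

#include "tensorflow/lite/model.h"                    // tflite::FlatBufferModel
#include "tensorflow/lite/schema/schema_generated.h"  // tflite::Model / ModelT

namespace {

// Same shape as the helper introduced in the patch: unpack into a fresh
// ModelT and return it by value so the caller's assignment replaces any
// previously held state.
tflite::ModelT UnPackFlatBufferModel(const tflite::Model& flatbuffer_model) {
  tflite::ModelT model;
  flatbuffer_model.UnPackTo(&model);
  return model;
}

// Hypothetical call site: `buffer` holds a serialized TFLite model (e.g. the
// output of a quantization pass). Assigning the returned value destroys what
// `model` previously owned rather than unpacking on top of it.
void ReloadQuantizedModel(const std::string& buffer, tflite::ModelT* model) {
  std::unique_ptr<tflite::FlatBufferModel> flatbuffer_model =
      tflite::FlatBufferModel::BuildFromBuffer(buffer.c_str(), buffer.size());
  *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel());
}

}  // namespace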