From a8929929f7a761312228c923ba5a9028b239e99f Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 11 May 2021 15:22:49 -0700 Subject: [PATCH] Fix heap OOB / undefined behavior in `RaggedTensorToTensor` PiperOrigin-RevId: 373244623 Change-Id: I2d6cbbc8c67b238a8815bf58097f7586d87c54f2 --- .../kernels/ragged_tensor_to_tensor_op.cc | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc index 376d55945d2ce8..b79a07e67ba913 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc @@ -207,8 +207,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel { DCHECK_EQ(result->size(), first_dimension); } - void CalculateOutputIndexRowSplit( - OpKernelContext* context, const RowPartitionTensor& row_split, + Status CalculateOutputIndexRowSplit( + const RowPartitionTensor& row_split, const vector& parent_output_index, INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, vector* result) { @@ -232,10 +232,11 @@ class RaggedTensorToTensorBaseOp : public OpKernel { result->push_back(-1); } } - if (row_split_size > 0) { - OP_REQUIRES(context, result->size() == row_split(row_split_size - 1), - errors::InvalidArgument("Invalid row split size.")); + if (row_split_size > 0 && result->size() != row_split(row_split_size - 1)) { + return errors::InvalidArgument("Invalid row split size."); } + + return Status::OK(); } // Calculate the output index of the first element of a list. @@ -259,20 +260,26 @@ class RaggedTensorToTensorBaseOp : public OpKernel { // result[6] = -1 because parent_output_index[value_rowids[6]] == -1 // result[7] = -1 because parent_output_index[value_rowids[6]] == -1 // result[8] = parent_output_index[value_rowids[7]] - void CalculateOutputIndexValueRowID( - OpKernelContext* context, const RowPartitionTensor& value_rowids, + Status CalculateOutputIndexValueRowID( + const RowPartitionTensor& value_rowids, const vector& parent_output_index, INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, vector* result) { const INDEX_TYPE index_size = value_rowids.size(); result->reserve(index_size); if (index_size == 0) { - return; + return Status::OK(); } INDEX_TYPE current_output_column = 0; INDEX_TYPE current_value_rowid = value_rowids(0); - DCHECK_LT(current_value_rowid, parent_output_index.size()); + + if (current_value_rowid >= parent_output_index.size()) { + return errors::InvalidArgument( + "Got current_value_rowid=", current_value_rowid, + " which is not less than ", parent_output_index.size()); + } + INDEX_TYPE current_output_index = parent_output_index[current_value_rowid]; result->push_back(current_output_index); for (INDEX_TYPE i = 1; i < index_size; ++i) { @@ -289,13 +296,23 @@ class RaggedTensorToTensorBaseOp : public OpKernel { } else { current_output_column = 0; current_value_rowid = next_value_rowid; - DCHECK_LT(next_value_rowid, parent_output_index.size()); + + if (next_value_rowid >= parent_output_index.size()) { + return errors::InvalidArgument( + "Got next_value_rowid=", next_value_rowid, + " which is not less than ", parent_output_index.size()); + } + current_output_index = parent_output_index[next_value_rowid]; } result->push_back(current_output_index); } - OP_REQUIRES(context, result->size() == value_rowids.size(), - errors::InvalidArgument("Invalid row ids.")); + + if (result->size() != value_rowids.size()) { + return errors::InvalidArgument("Invalid row ids."); + } + + return Status::OK(); } Status CalculateOutputIndex(OpKernelContext* context, int dimension, @@ -308,10 +325,9 @@ class RaggedTensorToTensorBaseOp : public OpKernel { auto partition_type = GetRowPartitionTypeByDimension(dimension); switch (partition_type) { case RowPartitionType::VALUE_ROWIDS: - CalculateOutputIndexValueRowID( - context, row_partition_tensor, parent_output_index, - output_index_multiplier, output_size, result); - return tensorflow::Status::OK(); + return CalculateOutputIndexValueRowID( + row_partition_tensor, parent_output_index, output_index_multiplier, + output_size, result); case RowPartitionType::ROW_SPLITS: if (row_partition_tensor.size() - 1 > parent_output_index.size()) { return errors::InvalidArgument( @@ -319,10 +335,9 @@ class RaggedTensorToTensorBaseOp : public OpKernel { row_partition_tensor.size() - 1, " > ", parent_output_index.size()); } - CalculateOutputIndexRowSplit( - context, row_partition_tensor, parent_output_index, - output_index_multiplier, output_size, result); - return tensorflow::Status::OK(); + return CalculateOutputIndexRowSplit( + row_partition_tensor, parent_output_index, output_index_multiplier, + output_size, result); default: return errors::InvalidArgument( "Unsupported partition type:",