Skip to content

Commit

Permalink
Merge pull request #49118 from geetachavan1/cherrypicks_BIDTR
Browse files Browse the repository at this point in the history
Fix heap OOB / undefined behavior in `RaggedTensorToTensor`
  • Loading branch information
mihaimaruseac committed May 12, 2021
2 parents 7dee4eb + a892992 commit b67f5b8
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
Expand Up @@ -207,8 +207,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
DCHECK_EQ(result->size(), first_dimension);
}

void CalculateOutputIndexRowSplit(
OpKernelContext* context, const RowPartitionTensor& row_split,
Status CalculateOutputIndexRowSplit(
const RowPartitionTensor& row_split,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
Expand All @@ -232,10 +232,11 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
result->push_back(-1);
}
}
if (row_split_size > 0) {
OP_REQUIRES(context, result->size() == row_split(row_split_size - 1),
errors::InvalidArgument("Invalid row split size."));
if (row_split_size > 0 && result->size() != row_split(row_split_size - 1)) {
return errors::InvalidArgument("Invalid row split size.");
}

return Status::OK();
}

// Calculate the output index of the first element of a list.
Expand All @@ -259,20 +260,26 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
// result[6] = -1 because parent_output_index[value_rowids[6]] == -1
// result[7] = -1 because parent_output_index[value_rowids[6]] == -1
// result[8] = parent_output_index[value_rowids[7]]
void CalculateOutputIndexValueRowID(
OpKernelContext* context, const RowPartitionTensor& value_rowids,
Status CalculateOutputIndexValueRowID(
const RowPartitionTensor& value_rowids,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
const INDEX_TYPE index_size = value_rowids.size();
result->reserve(index_size);
if (index_size == 0) {
return;
return Status::OK();
}

INDEX_TYPE current_output_column = 0;
INDEX_TYPE current_value_rowid = value_rowids(0);
DCHECK_LT(current_value_rowid, parent_output_index.size());

if (current_value_rowid >= parent_output_index.size()) {
return errors::InvalidArgument(
"Got current_value_rowid=", current_value_rowid,
" which is not less than ", parent_output_index.size());
}

INDEX_TYPE current_output_index = parent_output_index[current_value_rowid];
result->push_back(current_output_index);
for (INDEX_TYPE i = 1; i < index_size; ++i) {
Expand All @@ -289,13 +296,23 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
} else {
current_output_column = 0;
current_value_rowid = next_value_rowid;
DCHECK_LT(next_value_rowid, parent_output_index.size());

if (next_value_rowid >= parent_output_index.size()) {
return errors::InvalidArgument(
"Got next_value_rowid=", next_value_rowid,
" which is not less than ", parent_output_index.size());
}

current_output_index = parent_output_index[next_value_rowid];
}
result->push_back(current_output_index);
}
OP_REQUIRES(context, result->size() == value_rowids.size(),
errors::InvalidArgument("Invalid row ids."));

if (result->size() != value_rowids.size()) {
return errors::InvalidArgument("Invalid row ids.");
}

return Status::OK();
}

Status CalculateOutputIndex(OpKernelContext* context, int dimension,
Expand All @@ -308,21 +325,19 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
auto partition_type = GetRowPartitionTypeByDimension(dimension);
switch (partition_type) {
case RowPartitionType::VALUE_ROWIDS:
CalculateOutputIndexValueRowID(
context, row_partition_tensor, parent_output_index,
output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
return CalculateOutputIndexValueRowID(
row_partition_tensor, parent_output_index, output_index_multiplier,
output_size, result);
case RowPartitionType::ROW_SPLITS:
if (row_partition_tensor.size() - 1 > parent_output_index.size()) {
return errors::InvalidArgument(
"Row partition size is greater than output size: ",
row_partition_tensor.size() - 1, " > ",
parent_output_index.size());
}
CalculateOutputIndexRowSplit(
context, row_partition_tensor, parent_output_index,
output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
return CalculateOutputIndexRowSplit(
row_partition_tensor, parent_output_index, output_index_multiplier,
output_size, result);
default:
return errors::InvalidArgument(
"Unsupported partition type:",
Expand Down

0 comments on commit b67f5b8

Please sign in to comment.