Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix heap OOB / undefined behavior in RaggedTensorToTensor #49118

Merged
merged 1 commit into from May 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 35 additions & 20 deletions tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
Expand Up @@ -207,8 +207,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
DCHECK_EQ(result->size(), first_dimension);
}

void CalculateOutputIndexRowSplit(
OpKernelContext* context, const RowPartitionTensor& row_split,
Status CalculateOutputIndexRowSplit(
const RowPartitionTensor& row_split,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
Expand All @@ -232,10 +232,11 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
result->push_back(-1);
}
}
if (row_split_size > 0) {
OP_REQUIRES(context, result->size() == row_split(row_split_size - 1),
errors::InvalidArgument("Invalid row split size."));
if (row_split_size > 0 && result->size() != row_split(row_split_size - 1)) {
return errors::InvalidArgument("Invalid row split size.");
}

return Status::OK();
}

// Calculate the output index of the first element of a list.
Expand All @@ -259,20 +260,26 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
// result[6] = -1 because parent_output_index[value_rowids[6]] == -1
// result[7] = -1 because parent_output_index[value_rowids[7]] == -1
// result[8] = parent_output_index[value_rowids[8]]
void CalculateOutputIndexValueRowID(
OpKernelContext* context, const RowPartitionTensor& value_rowids,
Status CalculateOutputIndexValueRowID(
const RowPartitionTensor& value_rowids,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
const INDEX_TYPE index_size = value_rowids.size();
result->reserve(index_size);
if (index_size == 0) {
return;
return Status::OK();
}

INDEX_TYPE current_output_column = 0;
INDEX_TYPE current_value_rowid = value_rowids(0);
DCHECK_LT(current_value_rowid, parent_output_index.size());

if (current_value_rowid >= parent_output_index.size()) {
return errors::InvalidArgument(
"Got current_value_rowid=", current_value_rowid,
" which is not less than ", parent_output_index.size());
}

INDEX_TYPE current_output_index = parent_output_index[current_value_rowid];
result->push_back(current_output_index);
for (INDEX_TYPE i = 1; i < index_size; ++i) {
Expand All @@ -289,13 +296,23 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
} else {
current_output_column = 0;
current_value_rowid = next_value_rowid;
DCHECK_LT(next_value_rowid, parent_output_index.size());

if (next_value_rowid >= parent_output_index.size()) {
return errors::InvalidArgument(
"Got next_value_rowid=", next_value_rowid,
" which is not less than ", parent_output_index.size());
}

current_output_index = parent_output_index[next_value_rowid];
}
result->push_back(current_output_index);
}
OP_REQUIRES(context, result->size() == value_rowids.size(),
errors::InvalidArgument("Invalid row ids."));

if (result->size() != value_rowids.size()) {
return errors::InvalidArgument("Invalid row ids.");
}

return Status::OK();
}

Status CalculateOutputIndex(OpKernelContext* context, int dimension,
Expand All @@ -308,21 +325,19 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
auto partition_type = GetRowPartitionTypeByDimension(dimension);
switch (partition_type) {
case RowPartitionType::VALUE_ROWIDS:
CalculateOutputIndexValueRowID(
context, row_partition_tensor, parent_output_index,
output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
return CalculateOutputIndexValueRowID(
row_partition_tensor, parent_output_index, output_index_multiplier,
output_size, result);
case RowPartitionType::ROW_SPLITS:
if (row_partition_tensor.size() - 1 > parent_output_index.size()) {
return errors::InvalidArgument(
"Row partition size is greater than output size: ",
row_partition_tensor.size() - 1, " > ",
parent_output_index.size());
}
CalculateOutputIndexRowSplit(
context, row_partition_tensor, parent_output_index,
output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
return CalculateOutputIndexRowSplit(
row_partition_tensor, parent_output_index, output_index_multiplier,
output_size, result);
default:
return errors::InvalidArgument(
"Unsupported partition type:",
Expand Down