Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate that a and b are proper sparse tensors #49120

Merged
merged 1 commit into from May 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 12 additions & 5 deletions tensorflow/core/kernels/sparse_add_op.cc
Expand Up @@ -44,6 +44,11 @@ class SparseAddOp : public OpKernel {
b_indices->shape().DebugString()));
const int64 a_nnz = a_indices->dim_size(0);
const int64 b_nnz = b_indices->dim_size(0);
const int num_dims = a_indices->dim_size(1);
OP_REQUIRES(ctx, b_indices->dim_size(1) == num_dims,
errors::InvalidArgument(
"Input indices must have the same dimension, got ",
num_dims, " and ", b_indices->dim_size(1)));

OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t));
Expand Down Expand Up @@ -72,6 +77,13 @@ class SparseAddOp : public OpKernel {
"Input shapes should be a vector but received shapes ",
a_shape->shape().DebugString(), " and ",
b_shape->shape().DebugString()));
OP_REQUIRES(
ctx, a_shape->NumElements() == num_dims,
errors::InvalidArgument("Second dimension of a_indices and length of "
"a_shape must match, got ",
num_dims, " and ", a_shape->NumElements()));
OP_REQUIRES(ctx, num_dims > 0,
errors::InvalidArgument("Tesors must not be empty"));
OP_REQUIRES(
ctx, a_shape->IsSameSize(*b_shape),
errors::InvalidArgument(
Expand Down Expand Up @@ -100,11 +112,6 @@ class SparseAddOp : public OpKernel {
std::vector<std::pair<bool, int64>> entries_to_copy; // from_a?, idx
entries_to_copy.reserve(a_nnz + b_nnz);
std::vector<T> out_values;
const int num_dims = a_shape->dim_size(0);

OP_REQUIRES(ctx, num_dims > 0,
errors::InvalidArgument("Invalid input_a shape. Received: ",
a_shape->DebugString()));

// The input and output sparse tensors are assumed to be ordered along
// increasing dimension number.
Expand Down