Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate that a and b are proper sparse tensors #49128

Merged
merged 1 commit into from May 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 9 additions & 6 deletions tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
Expand Up @@ -150,6 +150,7 @@ class SparseSparseBinaryOpShared : public OpKernel {

const int64 a_nnz = a_indices_t->dim_size(0);
const int64 b_nnz = b_indices_t->dim_size(0);

const auto a_values = a_values_t->vec<T>();
const auto b_values = b_values_t->vec<T>();

Expand All @@ -166,6 +167,14 @@ class SparseSparseBinaryOpShared : public OpKernel {
"Input shapes should be a vector but received shapes ",
a_shape_t->shape().DebugString(), " and ",
b_shape_t->shape().DebugString()));
const int num_dims = a_indices_t->dim_size(1);
OP_REQUIRES(
ctx, a_shape_t->NumElements() == num_dims,
errors::InvalidArgument("Second dimension of a_indices and length of "
"a_shape must match, got ",
num_dims, " and ", a_shape_t->NumElements()));
OP_REQUIRES(ctx, num_dims > 0,
errors::InvalidArgument("Tensors must not be empty"));
OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t),
errors::InvalidArgument(
"Operands do not have the same ranks; got shapes: ",
Expand All @@ -180,12 +189,6 @@ class SparseSparseBinaryOpShared : public OpKernel {
" for dimension ", i));
}

OP_REQUIRES(
ctx, a_indices_t->dim_size(1) == b_indices_t->dim_size(1),
errors::InvalidArgument(
"Indices' dimensions do not match: got ", a_indices_t->dim_size(1),
" and ", b_indices_t->dim_size(1), " for the second dimension."));
const int num_dims = a_indices_t->dim_size(1);
const auto a_indices_mat = a_indices_t->matrix<int64>();
const auto b_indices_mat = b_indices_t->matrix<int64>();
std::vector<T> a_augmented_values, b_augmented_values;
Expand Down