From 6b2bf99cd9336026689579b683a709c5efcb4ae9 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 11 May 2021 18:32:03 -0700 Subject: [PATCH] Validate that a and b are proper sparse tensors PiperOrigin-RevId: 373274848 Change-Id: I3a665ac3a29dee9fb69bdf408a939330cb93ea75 --- .../kernels/sparse_sparse_binary_op_shared.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc index 9fe42e05d879ee..eb993a5965043b 100644 --- a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc +++ b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc @@ -150,6 +150,7 @@ class SparseSparseBinaryOpShared : public OpKernel { const int64 a_nnz = a_indices_t->dim_size(0); const int64 b_nnz = b_indices_t->dim_size(0); + const auto a_values = a_values_t->vec(); const auto b_values = b_values_t->vec(); @@ -166,6 +167,14 @@ class SparseSparseBinaryOpShared : public OpKernel { "Input shapes should be a vector but received shapes ", a_shape_t->shape().DebugString(), " and ", b_shape_t->shape().DebugString())); + const int num_dims = a_indices_t->dim_size(1); + OP_REQUIRES( + ctx, a_shape_t->NumElements() == num_dims, + errors::InvalidArgument("Second dimension of a_indices and length of " + "a_shape must match, got ", + num_dims, " and ", a_shape_t->NumElements())); + OP_REQUIRES(ctx, num_dims > 0, + errors::InvalidArgument("Tensors must not be empty")); OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t), errors::InvalidArgument( "Operands do not have the same ranks; got shapes: ", @@ -180,12 +189,6 @@ class SparseSparseBinaryOpShared : public OpKernel { " for dimension ", i)); } - OP_REQUIRES( - ctx, a_indices_t->dim_size(1) == b_indices_t->dim_size(1), - errors::InvalidArgument( - "Indices' dimensions do not match: got ", a_indices_t->dim_size(1), - " and ", b_indices_t->dim_size(1), " for the second dimension.")); - const int num_dims = a_indices_t->dim_size(1); const auto a_indices_mat = a_indices_t->matrix(); const auto b_indices_mat = b_indices_t->matrix(); std::vector a_augmented_values, b_augmented_values;