From 8482d7449805d7647ed5b3b47513677a22ae2993 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 11 May 2021 15:41:51 -0700 Subject: [PATCH] Validate that a and b are proper sparse tensors PiperOrigin-RevId: 373248068 Change-Id: I0a2041a0747901b3f00387a6a3bce9bca6b0b3b1 --- tensorflow/core/kernels/sparse_add_op.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/sparse_add_op.cc b/tensorflow/core/kernels/sparse_add_op.cc index 346206365af8d5..2bd05fa41adc26 100644 --- a/tensorflow/core/kernels/sparse_add_op.cc +++ b/tensorflow/core/kernels/sparse_add_op.cc @@ -44,6 +44,11 @@ class SparseAddOp : public OpKernel { b_indices->shape().DebugString())); const int64 a_nnz = a_indices->dim_size(0); const int64 b_nnz = b_indices->dim_size(0); + const int num_dims = a_indices->dim_size(1); + OP_REQUIRES(ctx, b_indices->dim_size(1) == num_dims, + errors::InvalidArgument( + "Input indices must have the same dimension, got ", + num_dims, " and ", b_indices->dim_size(1))); OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t)); OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t)); @@ -72,6 +77,13 @@ class SparseAddOp : public OpKernel { "Input shapes should be a vector but received shapes ", a_shape->shape().DebugString(), " and ", b_shape->shape().DebugString())); + OP_REQUIRES( + ctx, a_shape->NumElements() == num_dims, + errors::InvalidArgument("Second dimension of a_indices and length of " + "a_shape must match, got ", + num_dims, " and ", a_shape->NumElements())); + OP_REQUIRES(ctx, num_dims > 0, + errors::InvalidArgument("Tesors must not be empty")); OP_REQUIRES( ctx, a_shape->IsSameSize(*b_shape), errors::InvalidArgument( @@ -100,11 +112,6 @@ class SparseAddOp : public OpKernel { std::vector> entries_to_copy; // from_a?, idx entries_to_copy.reserve(a_nnz + b_nnz); std::vector out_values; - const int num_dims = a_shape->dim_size(0); - - OP_REQUIRES(ctx, num_dims > 0, - errors::InvalidArgument("Invalid input_a shape. Received: ", - a_shape->DebugString())); // The input and output sparse tensors are assumed to be ordered along // increasing dimension number.