From a838d9e739b4a4b948ae867b28aa17a5dcb45c38 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 30 Sep 2022 14:49:42 -0700 Subject: [PATCH] Fix empty input crash for SparseFillEmptyRowsGrad. PiperOrigin-RevId: 478085721 --- .../sparse_fill_empty_rows_op_gpu.cu.cc | 43 +++++++++++-------- .../sparse_ops/sparse_ops_test.py | 7 +++ 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc index 8ef4ce6172f367..2efa88106ab523 100644 --- a/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op_gpu.cu.cc @@ -297,9 +297,12 @@ struct SparseFillEmptyRows { empty_row_indicator = empty_row_indicator_t.vec().data(); } - TF_RETURN_IF_ERROR(wrap_kernel_call(ComputeEmptyRowIndicatorKernel, - /*device=*/device, /*size=*/dense_rows, - elements_per_row, empty_row_indicator)); + if (dense_rows > 0) { + TF_RETURN_IF_ERROR( + wrap_kernel_call(ComputeEmptyRowIndicatorKernel, + /*device=*/device, /*size=*/dense_rows, + elements_per_row, empty_row_indicator)); + } // For each row, the number of empty rows up to and including that row. Tensor num_empty_rows_through_t; @@ -405,14 +408,16 @@ struct SparseFillEmptyRows { done); } - OP_REQUIRES_OK_ASYNC( - context, - wrap_kernel_call(ScatterNewElementsKernel, - /*device=*/device, /*size=*/dense_rows, rank, - default_value, num_empty_rows_through, - input_row_ends, empty_row_indicator, output_indices, - output_values), - done); + if (dense_rows > 0) { + OP_REQUIRES_OK_ASYNC( + context, + wrap_kernel_call(ScatterNewElementsKernel, + /*device=*/device, /*size=*/dense_rows, rank, + default_value, num_empty_rows_through, + input_row_ends, empty_row_indicator, + output_indices, output_values), + done); + } done(); }; @@ -461,9 +466,11 @@ struct SparseFillEmptyRows { TF_RETURN_IF_ERROR( context->allocate_temp(index_type, TensorShape({N}), &row_indices_t)); auto row_indices = row_indices_t.flat(); - TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel, - /*device=*/device, /*size=*/N, rank, - indices, row_indices)); + if (N > 0) { + TF_RETURN_IF_ERROR(wrap_kernel_call(CopyRowIndicesKernel, + /*device=*/device, /*size=*/N, rank, + indices, row_indices)); + } // Allocate input_index_map. TF_RETURN_IF_ERROR(context->allocate_temp(index_type, TensorShape({N}), input_index_map_t)); @@ -528,9 +535,11 @@ struct SparseFillEmptyRowsGrad { auto visited = visited_t.vec(); visited.device(device) = visited.constant(false); - TF_RETURN_IF_ERROR(wrap_kernel_call( - GatherOriginalGradValuesKernel, /*device=*/device, - /*size=*/N, reverse_index_map, grad_values, d_values, visited)); + if (N > 0) { + TF_RETURN_IF_ERROR(wrap_kernel_call( + GatherOriginalGradValuesKernel, /*device=*/device, + /*size=*/N, reverse_index_map, grad_values, d_values, visited)); + } // Now we mask out the visited values and sum the remaining ones (which // correspond to the empty rows in the forward input) to compute diff --git a/tensorflow/python/kernel_tests/sparse_ops/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops/sparse_ops_test.py index 684d1f98432b53..c1ceac68040318 100644 --- a/tensorflow/python/kernel_tests/sparse_ops/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops/sparse_ops_test.py @@ -514,6 +514,13 @@ def testFillNumber(self): self.assertAllEqual(empty_row_indicator_out, np.array([0, 0, 1, 0, 1]).astype(np.bool_)) + def testSparseFillEmptyRowsGradEmpty(self): + with test_util.use_gpu(): + grad, _ = self.evaluate( + sparse_ops.sparse_fill_empty_rows_grad( + reverse_index_map=[], grad_values=[])) + self.assertAllEqual(grad, []) + @test_util.run_deprecated_v1 def testFillFloat(self): with self.session():