diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 63b63b9473e403..f9b0ec070488ff 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2879,6 +2879,10 @@ REGISTER_OP("QuantizeAndDequantizeV2") axis); } else if (axis != -1) { ShapeHandle input; + if (axis >= kint32max) { + return errors::InvalidArgument( + "Axis cannot be >= kint32max value, got ", axis); + } TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); DimensionHandle depth; TF_RETURN_IF_ERROR( @@ -2914,6 +2918,10 @@ REGISTER_OP("QuantizeAndDequantizeV4") axis); } else if (axis != -1) { ShapeHandle input; + if (axis >= kint32max) { + return errors::InvalidArgument( + "Axis cannot be >= kint32max value, got ", axis); + } TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); DimensionHandle depth; TF_RETURN_IF_ERROR( @@ -2945,6 +2953,10 @@ REGISTER_OP("QuantizeAndDequantizeV4Grad") axis); } else if (axis != -1) { ShapeHandle input; + if (axis >= kint32max) { + return errors::InvalidArgument( + "Axis cannot be >= kint32max value, got ", axis); + } TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); DimensionHandle depth; TF_RETURN_IF_ERROR( @@ -2981,6 +2993,10 @@ REGISTER_OP("QuantizeAndDequantizeV3") axis); } else if (axis != -1) { ShapeHandle input; + if (axis >= kint32max) { + return errors::InvalidArgument( + "Axis cannot be >= kint32max value, got ", axis); + } TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); DimensionHandle depth; TF_RETURN_IF_ERROR( diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 1a55c64b46934b..f54c92c5e21d61 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -1791,6 +1791,72 @@ def testOutOfBoundAxis(self): max_range=input_max, axis=2**31 - 1)) + @test_util.run_v2_only + def testInvalidAxis(self): + + @def_function.function + def test_quantize_and_dequantize_v2(): + gen_array_ops.quantize_and_dequantize_v2( + input=[2.5], + input_min=[1.0], + input_max=[10.0], + signed_input=True, + num_bits=1, + range_given=True, + round_mode="HALF_TO_EVEN", + narrow_range=True, + axis=0x7fffffff) + + @def_function.function + def test_quantize_and_dequantize_v3(): + gen_array_ops.quantize_and_dequantize_v3( + input=[2.5], + input_min=[1.0], + input_max=[10.0], + num_bits=1, + signed_input=True, + range_given=True, + narrow_range=True, + axis=0x7fffffff) + + @def_function.function + def test_quantize_and_dequantize_v4(): + gen_array_ops.quantize_and_dequantize_v4( + input=[2.5], + input_min=[1.0], + input_max=[10.0], + signed_input=True, + num_bits=1, + range_given=True, + round_mode="HALF_TO_EVEN", + narrow_range=True, + axis=0x7fffffff) + + @def_function.function + def test_quantize_and_dequantize_v4_grad(): + gen_array_ops.quantize_and_dequantize_v4_grad( + gradients=[2.5], + input=[2.5], + input_min=[1.0], + input_max=[10.0], + axis=0x7fffffff) + + with self.assertRaisesRegex( + ValueError, "Axis cannot be >= kint32max value, got 2147483647"): + test_quantize_and_dequantize_v2() + + with self.assertRaisesRegex( + ValueError, "Axis cannot be >= kint32max value, got 2147483647"): + test_quantize_and_dequantize_v3() + + with self.assertRaisesRegex( + ValueError, "Axis cannot be >= kint32max value, got 2147483647"): + test_quantize_and_dequantize_v4() + + with self.assertRaisesRegex( + ValueError, "Axis cannot be >= kint32max value, got 2147483647"): + test_quantize_and_dequantize_v4_grad() + @test_util.run_all_in_graph_and_eager_modes class SortedSearchTest(test_util.TensorFlowTestCase):