Skip to content

Commit

Permalink
Merge pull request #58570 from tensorflow/r2.9-7b174a0f2e4
Browse files Browse the repository at this point in the history
r2.9 cherry-pick: 7b174a0 "Fix asan issue with QuantizeAndDequantizeV2/V3/V4/V4Grad shape inference functions."
  • Loading branch information
mihaimaruseac committed Nov 14, 2022
2 parents fb6df70 + eb92930 commit 5dbe90a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tensorflow/core/ops/array_ops.cc
Expand Up @@ -2879,6 +2879,10 @@ REGISTER_OP("QuantizeAndDequantizeV2")
axis);
} else if (axis != -1) {
ShapeHandle input;
if (axis >= kint32max) {
return errors::InvalidArgument(
"Axis cannot be >= kint32max value, got ", axis);
}
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
DimensionHandle depth;
TF_RETURN_IF_ERROR(
Expand Down Expand Up @@ -2914,6 +2918,10 @@ REGISTER_OP("QuantizeAndDequantizeV4")
axis);
} else if (axis != -1) {
ShapeHandle input;
if (axis >= kint32max) {
return errors::InvalidArgument(
"Axis cannot be >= kint32max value, got ", axis);
}
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
DimensionHandle depth;
TF_RETURN_IF_ERROR(
Expand Down Expand Up @@ -2945,6 +2953,10 @@ REGISTER_OP("QuantizeAndDequantizeV4Grad")
axis);
} else if (axis != -1) {
ShapeHandle input;
if (axis >= kint32max) {
return errors::InvalidArgument(
"Axis cannot be >= kint32max value, got ", axis);
}
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
DimensionHandle depth;
TF_RETURN_IF_ERROR(
Expand Down Expand Up @@ -2981,6 +2993,10 @@ REGISTER_OP("QuantizeAndDequantizeV3")
axis);
} else if (axis != -1) {
ShapeHandle input;
if (axis >= kint32max) {
return errors::InvalidArgument(
"Axis cannot be >= kint32max value, got ", axis);
}
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
DimensionHandle depth;
TF_RETURN_IF_ERROR(
Expand Down
66 changes: 66 additions & 0 deletions tensorflow/python/kernel_tests/array_ops/array_ops_test.py
Expand Up @@ -1791,6 +1791,72 @@ def testOutOfBoundAxis(self):
max_range=input_max,
axis=2**31 - 1))

@test_util.run_v2_only
def testInvalidAxis(self):

@def_function.function
def test_quantize_and_dequantize_v2():
gen_array_ops.quantize_and_dequantize_v2(
input=[2.5],
input_min=[1.0],
input_max=[10.0],
signed_input=True,
num_bits=1,
range_given=True,
round_mode="HALF_TO_EVEN",
narrow_range=True,
axis=0x7fffffff)

@def_function.function
def test_quantize_and_dequantize_v3():
gen_array_ops.quantize_and_dequantize_v3(
input=[2.5],
input_min=[1.0],
input_max=[10.0],
num_bits=1,
signed_input=True,
range_given=True,
narrow_range=True,
axis=0x7fffffff)

@def_function.function
def test_quantize_and_dequantize_v4():
gen_array_ops.quantize_and_dequantize_v4(
input=[2.5],
input_min=[1.0],
input_max=[10.0],
signed_input=True,
num_bits=1,
range_given=True,
round_mode="HALF_TO_EVEN",
narrow_range=True,
axis=0x7fffffff)

@def_function.function
def test_quantize_and_dequantize_v4_grad():
gen_array_ops.quantize_and_dequantize_v4_grad(
gradients=[2.5],
input=[2.5],
input_min=[1.0],
input_max=[10.0],
axis=0x7fffffff)

with self.assertRaisesRegex(
ValueError, "Axis cannot be >= kint32max value, got 2147483647"):
test_quantize_and_dequantize_v2()

with self.assertRaisesRegex(
ValueError, "Axis cannot be >= kint32max value, got 2147483647"):
test_quantize_and_dequantize_v3()

with self.assertRaisesRegex(
ValueError, "Axis cannot be >= kint32max value, got 2147483647"):
test_quantize_and_dequantize_v4()

with self.assertRaisesRegex(
ValueError, "Axis cannot be >= kint32max value, got 2147483647"):
test_quantize_and_dequantize_v4_grad()


@test_util.run_all_in_graph_and_eager_modes
class SortedSearchTest(test_util.TensorFlowTestCase):
Expand Down

0 comments on commit 5dbe90a

Please sign in to comment.