New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor shape inference test PR2: extend to Op versions <= 5 #5354
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Yuhong Guo <yuhong.gyh@antgroup.com>
Signed-off-by: Yuhong Guo <yuhong.gyh@antgroup.com>
Signed-off-by: Yuhong Guo <yuhong.gyh@antgroup.com>
a876cd6
to
15b0e5a
Compare
Signed-off-by: Yuhong Guo <yuhong.gyh@antgroup.com>
15b0e5a
to
6ceda4d
Compare
@justinchuby @jcwchen @gramalingam Please take a look at this code change. |
@@ -32349,7 +32349,7 @@ expect(node, inputs=[x, axes], outputs=[y], name="test_unsqueeze_unsorted_axes") | |||
|
|||
This version of the operator has been deprecated since version 10 of the default ONNX operator set. | |||
|
|||
Other versions of this operator: <a href="Changelog.md#Upsample-7">7</a>, <a href="Changelog.md#Upsample-9">9</a> | |||
Other versions of this operator: <a href="Changelog.md#Upsample-1">1</a>, <a href="Changelog.md#Upsample-7">7</a>, <a href="Changelog.md#Upsample-9">9</a>, <a href="Changelog.md#Upsample-10">10</a> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious why this was updated?
class TestShapeInferenceHelper(unittest.TestCase): | ||
def _make_graph( | ||
self, | ||
seed_values: Sequence[str | tuple[str, TensorProto.DataType, Any]], | ||
nodes: list[NodeProto], | ||
value_info: list[ValueInfoProto], | ||
# TODO: remove the default value when all _make_graph calls have reshape_version. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you open a tracking issue for this todo and link it here?
inferred_model = self._inferred(graph_or_model, **kwargs) | ||
inferred_vis = list(inferred_model.graph.value_info) | ||
vis = sorted(vis, key=lambda x: x.name) # type: ignore[no-any-return] | ||
inferred_vis = sorted(inferred_vis, key=lambda x: x.name) # type: ignore | ||
assert len(vis) == len(inferred_vis) | ||
if has_inference_func: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add a comment to explain the logic here. It is not immediately clear what the +1 is doing or what invariants it is keeping?
) | ||
self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph) | ||
if version == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hope we have less logic inside a test (avoid if), but at the same time I don’t have a good alternative here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May be define a utility function like
def InferenceErrorFor(version):
return RuntimeError if (version == 1) else InferenceError
and use it?
) | ||
self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph) | ||
if version == 1: | ||
self.assertRaises(RuntimeError, self._inferred, graph) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would comment on why the logic is needed to distinguish version 1
) | ||
self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph) | ||
if version == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potentially capture this logic in a function so we don’t duplicate it in many places?
@@ -380,15 +432,22 @@ def test_matmul_allow_unknown(self, _, version) -> None: | |||
|
|||
@parameterized.expand(all_versions_for("Cast")) | |||
def test_cast(self, _, version) -> None: | |||
if version == 1: | |||
target_type = "UINT8" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just skip version 1 for this? @jcwchen
@@ -619,70 +702,116 @@ def test_reshape_dynamic_shape_known_rank(self, _, version) -> None: | |||
|
|||
@parameterized.expand(all_versions_for("Reshape")) | |||
def test_reshape_dynamic_shape_symbolic(self, _, version) -> None: | |||
seed_values: list[tuple[str, TensorProto.DataType, Any]] = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A seed tends to make readers expect random generation? Would « base » be more accurate?
@@ -705,6 +835,10 @@ def test_reshape_static_shape_allowzero(self, _, version) -> None: | |||
|
|||
@parameterized.expand(all_versions_for("Reshape")) | |||
def test_reshape_static_shape_constant(self, _, version) -> None: | |||
if version == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this just version 1? Will we be able to ignore it to reduce much of the special case handling here?
) | ||
|
||
@parameterized.expand(all_versions_for("Upsample")) | ||
def test_upsample(self, _, version) -> None: | ||
if version == 7: | ||
if version <= 7: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
@@ -262,10 +288,12 @@ def test_transpose_preexisting(self, _, version) -> None: | |||
|
|||
@parameterized.expand(all_versions_for("Transpose")) | |||
def test_transpose_scalar(self, _, version) -> None: | |||
self.skipIf(version == 1, "Transpose a scale is not supported in Version 1") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Why not just modify the input to @parameterized.expand([x for x in all_versions_for("Transpose") if x > 1])
? Better still, I think it would be simpler to extend the function all_versions_for
to take an optional lower_bound as a parameter, so we can write @parameterized.expand(all_versions_for("Transpose", from=2))
reshape_inputs = ["SEED_" + seed_name, "UNKNOWN_SHAPE_" + seed_name] | ||
else: | ||
# For Reshape of Version 1, it has only one input: data. | ||
reshape_inputs = ["SEED_" + seed_name] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem right. As the comment above (new linenumber 100) imply, the goal is to create a tensor of unknown-shape (to test the shape-inference logic). And a Reshape with shape as attribute (i.e., before version 5) cannot do that. We may need an alternative way to create a tensor of unknown-shape in version < 5.
Description
Previous refactor PR #5263 only extend the test for Op Version > 5, because there are attribute change for Op "Reshape" which is used in
TestShapeInferenceHelper._make_graph
function.This PR does the following change:
reshape_version
to functionTestShapeInferenceHelper._make_graph
. Then, this function can use it to call different version ofmake_node
forReshape
Op.TestShapeInferenceHelper._assert_inferred
, for some Ops' Version 1 implementation, it's SchemaOp'shas_type_and_shape_inference_function
is False.After this PR, the tests with annotation
@parameterized.expand(all_versions_for("${OpName}"))
covers all Op Versions.Motivation and Context
Fix: #5289
Related Issue: #4160