Skip to content

Commit

Permalink
[Cherry-pick]Fix SegFault bug in shape inference (#5995)
Browse files Browse the repository at this point in the history
### Description
Cherry-Pick #5990 into `rel-1.16.0` branch

The bug is SegFault during shape inference if the schema not be set
inference function.

The schema has `CheckInputOutputType` for shape inference and `Verify`
for raw node proto. And the behavior is not fully aligned.

### Motivation and Context

fixes #5989

Signed-off-by: opluss <opluss@qq.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: oPluss <opluss@qq.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 5, 2024
1 parent e6ff7b9 commit c867683
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
24 changes: 24 additions & 0 deletions onnx/defs/schema.cc
Expand Up @@ -107,6 +107,30 @@ OpSchemaRegistry* OpSchemaRegistry::Instance() {

void OpSchema::CheckInputOutputType(struct InferenceContext& ctx) const {
std::unordered_map<std::string, std::string> type_constraints;
if (inputs_.empty() && ctx.getNumInputs() > 0) {
fail_check(
"Node (",
domain(),
"::",
Name(),
":",
since_version(),
") takes zero inputs, but got ",
ctx.getNumInputs(),
" in graph");
}
if (outputs_.empty() && ctx.getNumOutputs() > 0) {
fail_check(
"Node (",
domain(),
"::",
Name(),
":",
since_version(),
") yields zero outputs, but got ",
ctx.getNumOutputs(),
" in graph");
}
// check all input types
for (size_t in_idx = 0; in_idx < ctx.getNumInputs(); ++in_idx) {
// If the last input is Variadic by definition, checker still needs to check the rest of actual input's type
Expand Down
27 changes: 27 additions & 0 deletions onnx/test/shape_inference_test.py
Expand Up @@ -9875,6 +9875,33 @@ def test_compress_with_axis(self) -> None:
)
self._assert_inferred(graph, [make_tensor_value_info("output", TensorProto.INT64, (2, "N", 3, None))]) # type: ignore

def test_check_type_when_schema_has_empty_io(self):
input = """
<
ir_version: 7,
opset_import: ["" : 1]
>
agraph (X, Y) => (Z)
{
Z = CustomOp(X, Y)
}
"""
model = onnx.parser.parse_model(input)

op_schema = defs.OpSchema(
"CustomOp",
"",
1,
inputs=[],
outputs=[],
)
onnx.defs.register_schema(op_schema)
with self.assertRaises(onnx.shape_inference.InferenceError):
onnx.shape_inference.infer_shapes(model, True)
onnx.defs.deregister_schema(
op_schema.name, op_schema.since_version, op_schema.domain
)


if __name__ == "__main__":
unittest.main()

0 comments on commit c867683

Please sign in to comment.