Differentiate shape and type inference errors #5519

Open · wants to merge 14 commits into base: main
onnx/checker.cc (4 changes: 2 additions & 2 deletions)
@@ -1036,7 +1036,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
check_model(model, ctx);

if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
Contributor:

Would it be possible to scope this enum for clarity? e.g. errors::FailAnyInferenceError or shape_inference::FailAnyInferenceError etc.

@jcwchen any suggestions?

Contributor (author):

That is correct; this PR is far from complete. There are many changes to the API surface, so we want to get it right before merging. I am more concerned about the Python side; naming is one thing we need to be precise about. In terms of scope, it already lives under the shape_inference module, and this PR adds the errors there.
It is less of a concern on the C++ side, however.

Contributor:

Curious why it's less of a concern on the C++ side?

Contributor (author):

I should have said it is a concern for both; it is just that the Python API is more visible.

Contributor (author):

@justinchuby it is better to scope the enum. Do you think InferenceErrorMode, instead of error, is a better name?

enum class InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};

ONNX_NAMESPACE::shape_inference::InferShapes(model, ctx.get_schema_registry(), options);
}
}
@@ -1046,7 +1046,7 @@ void check_model(const ModelProto& model, bool full_check, bool skip_opset_compa
ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check);
check_model(model, ctx);
if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
// Do not update the model in place by the check from shape inference
// because checker should not modify the original model
ModelProto copy = model;
onnx/cpp2py_export.cc (22 changes: 15 additions & 7 deletions)
@@ -79,9 +79,9 @@
shape_inference::GraphInferenceContext graphInferenceContext(
valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
// Construct inference context and get results - may throw InferenceError
// TODO: if it is desirable for infer_node_outputs to provide check_type, strict_mode, data_prop,
// TODO: if it is desirable for infer_node_outputs to provide check_type, error_mode, data_prop,

[Optional Lint] cpplint, onnx/cpp2py_export.cc:82: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo]
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{false, 0, false};
ShapeInferenceOptions options{false, IgnoreInferenceError, false};
shape_inference::InferenceContextImpl ctx(
node, valueTypes.second, inputData.second, inputSparseData.second, options, nullptr, &graphInferenceContext);
schema->GetTypeAndShapeInferenceFunction()(ctx);
@@ -609,31 +609,39 @@
auto shape_inference = onnx_cpp2py_export.def_submodule("shape_inference");
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");
py::register_exception<TypeError>(shape_inference, "TypeError");
py::register_exception<ShapeError>(shape_inference, "ShapeError");

[Optional Lint] cpplint, onnx/cpp2py_export.cc:614: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] (fixed)
py::enum_<InferenceErrorMode>(shape_inference, "InferenceErrorMode")
.value("IgnoreInferenceError", IgnoreInferenceError)
.value("FailAnyInferenceError", FailAnyInferenceError)
.value("FailShapeInferenceError", FailShapeInferenceError)
.value("FailTypeInferenceError", FailTypeInferenceError);

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
[](const py::bytes& bytes, bool check_type, InferenceErrorMode error_mode, bool data_prop) {
ModelProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(proto, OpSchemaRegistry::Instance(), options);
std::string out;
proto.SerializeToString(&out);
return py::bytes(out);
},
"bytes"_a,
"check_type"_a = false,
"strict_mode"_a = false,
"error_mode"_a = IgnoreInferenceError,
"data_prop"_a = false);

shape_inference.def(
"infer_shapes_path",
[](const std::string& model_path,
const std::string& output_path,
bool check_type,
bool strict_mode,
InferenceErrorMode error_mode,
bool data_prop) -> void {
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(model_path, output_path, OpSchemaRegistry::Instance(), options);
});
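
For a sense of the Python surface these bindings create, here is a minimal sketch of calling infer_shapes with the new error_mode argument. Parameter names and enum values are taken from the diff above; the model path is hypothetical.

import onnx
from onnx import shape_inference

# Hypothetical model file; any valid ONNX model would do.
model = onnx.load("model.onnx")

# error_mode replaces the old boolean strict_mode; the enum is exposed
# under the shape_inference submodule by the pybind registration above.
inferred = shape_inference.infer_shapes(
    model,
    check_type=True,
    error_mode=shape_inference.InferenceErrorMode.FailShapeInferenceError,
    data_prop=False,
)
print(inferred.graph.value_info)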

onnx/defs/shape_inference.h (39 changes: 28 additions & 11 deletions)
@@ -18,18 +18,22 @@

using Dim = TensorShapeProto_Dimension;

enum InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors

[Optional Lint] cpplint, onnx/defs/shape_inference.h:22: At least two spaces is best between code and comments [whitespace/comments]
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};

struct ShapeInferenceOptions {
// Checks the type-equality for input and output
bool check_type;
// 1: Will throw any node level shape infer errors
// 0: Won't throw node-level shape infer errors, but other errors
// like merging existing shape with inferred etc are thrown
int error_mode;
InferenceErrorMode error_mode;

[clang-tidy] onnx/defs/shape_inference.h:31: warning: member variable 'error_mode' has public visibility [misc-non-private-member-variables-in-classes]
// Enables data propagation for limited operators
// to perform shape computation
bool enable_data_propagation;
ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
: check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
ShapeInferenceOptions(bool check_type_val = false, InferenceErrorMode error_mode_val = IgnoreInferenceError, bool data_prop_val = false)

[Optional Lint] cpplint, onnx/defs/shape_inference.h:35: Lines should be <= 120 characters long [whitespace/line_length]
[Optional Lint] cpplint, onnx/defs/shape_inference.h:35: Constructors callable with one argument should be marked explicit. [runtime/explicit]
: check_type(check_type_val), error_mode(error_mode_val), enable_data_propagation(data_prop_val){};

[Optional Lint] cpplint, onnx/defs/shape_inference.h:36: You don't need a ; after a } [readability/braces]
};

// Maintains a SymbolTable for symbolic shape inference
@@ -55,9 +59,8 @@
virtual ~GraphInferencer() = default;
};

// Exception class used for handling errors in type and shape inference

class InferenceError final : public std::runtime_error {
// Exception classes used for handling errors in type and shape inference
class InferenceError : public std::runtime_error {
public:
using std::runtime_error::runtime_error;

@@ -78,12 +81,26 @@
std::string expanded_message_;
};

class ShapeError final : public InferenceError {
public:
ShapeError(const std::string& message)
: InferenceError(ONNX_NAMESPACE::MakeString("[ShapeError] ", message)) {}
};

class TypeError final : public InferenceError {
public:
TypeError(const std::string& message)
: InferenceError(ONNX_NAMESPACE::MakeString("[TypeError] ", message)) {}
};

#define fail_type_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));
ONNX_THROW_EX(ONNX_NAMESPACE::TypeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

#define fail_shape_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
ONNX_THROW_EX(ONNX_NAMESPACE::ShapeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

#define fail_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

[clang-tidy] onnx/defs/shape_inference.h:102: warning: variadic macro 'fail_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]

struct InferenceContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
virtual size_t getNumInputs() const = 0;
onnx/shape_inference.py (16 changes: 12 additions & 4 deletions)
@@ -20,7 +20,7 @@
def infer_shapes(
model: ModelProto | bytes,
check_type: bool = False,
strict_mode: bool = False,
Contributor:

We need to keep the strict_mode option for backwards compatibility.

error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

[lintrunner] MYPY name-defined Error: Name "C.InferenceErrorMode" is not defined
[lintrunner] MYPY attr-defined Error: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
data_prop: bool = False,
) -> ModelProto:
"""Apply shape inference to the provided ModelProto.
@@ -36,6 +36,11 @@
check_type: Checks the type-equality for input and output.
strict_mode: Stricter shape inference, it will throw errors if any;
Otherwise, simply stop if any error.
error_mode (InferenceErrorMode):
IgnoreInferenceError: ignore any inference error and continue
FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors
FailShapeInferenceError: stop on shape inference error and throw the exception
FailTypeInferenceError: stop on type inference error and throw the exception
data_prop: Enables data propagation for limited operators to perform shape computation.

Returns:
@@ -44,7 +49,7 @@
if isinstance(model, (ModelProto, bytes)):
model_str = model if isinstance(model, bytes) else model.SerializeToString()
inferred_model_str = C.infer_shapes(
model_str, check_type, strict_mode, data_prop
model_str, check_type, error_mode, data_prop
)
return onnx.load_from_string(inferred_model_str)
if isinstance(model, str):
@@ -62,7 +67,7 @@
model_path: str | os.PathLike,
output_path: str | os.PathLike = "",
check_type: bool = False,
strict_mode: bool = False,
error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

[lintrunner] MYPY name-defined Error: Name "C.InferenceErrorMode" is not defined
[lintrunner] MYPY attr-defined Error: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
data_prop: bool = False,
) -> None:
"""Take model path for shape_inference.
@@ -93,7 +98,7 @@

if output_path == "":
output_path = model_path
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
C.infer_shapes_path(model_path, output_path, check_type, error_mode, data_prop)


def infer_node_outputs(
@@ -169,3 +174,6 @@


InferenceError = C.InferenceError
TypeError = C.TypeError

[lintrunner] MYPY attr-defined Error: Module has no attribute "TypeError"
[lintrunner] RUFF/A001 Warning: Variable TypeError is shadowing a Python builtin (see https://docs.astral.sh/ruff/rules/builtin-variable-shadowing)
ShapeError = C.ShapeError

[lintrunner] MYPY attr-defined Error: Module has no attribute "ShapeError"
InferenceErrorMode = C.InferenceErrorMode

[lintrunner] MYPY attr-defined Error: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
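
One way to address the backwards-compatibility concern raised on infer_shapes above would be a shim that accepts both the legacy strict_mode flag and the new enum. This is not part of the diff, only a sketch of the mapping the reviewer is asking for; infer_shapes_compat and its defaults are invented for illustration.

from onnx import shape_inference

def infer_shapes_compat(model, check_type=False, strict_mode=None,
                        error_mode=None, data_prop=False):
    # Hypothetical shim, not in this PR: map the legacy boolean
    # strict_mode onto the new InferenceErrorMode enum when the caller
    # does not pass error_mode explicitly.
    if error_mode is None:
        error_mode = (
            shape_inference.InferenceErrorMode.FailAnyInferenceError
            if strict_mode
            else shape_inference.InferenceErrorMode.IgnoreInferenceError
        )
    return shape_inference.infer_shapes(
        model, check_type=check_type, error_mode=error_mode, data_prop=data_prop
    )
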
onnx/shape_inference/implementation.cc (79 changes: 58 additions & 21 deletions)
@@ -497,18 +497,33 @@
}
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
// Note: The following special handling is to accommodate custom-ops. Ideally, custom-ops
// should be registered with a schema in the schema registry, allowing inference to handle
// them. As things stand, this special handling is somewhat fragile and is not fully
// general either. Eg., a custom-op suppresses error-messages for subsequent nodes, but
// this does not work across graphs. If special handling is required, a user-option may
// be a better way to do it. The fragility comes from the fact that the types of the
// returned-values of the custom-op are unknown, and subsequent node-level inference
// may fail because of this.
ONNX_CATCH(const ONNX_NAMESPACE::TypeError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// Note: The following special handling is to accommodate custom-ops. Ideally, custom-ops
// should be registered with a schema in the schema registry, allowing inference to handle
// them. As things stand, this special handling is somewhat fragile and is not fully
// general either. Eg., a custom-op suppresses error-messages for subsequent nodes, but
// this does not work across graphs. If special handling is required, a user-option may
// be a better way to do it. The fragility comes from the fact that the types of the
// returned-values of the custom-op are unknown, and subsequent node-level inference
// may fail because of this.
// onnx does not support unsupported/experimental operators
// so it won't consider it as an error
if (!has_unsupported_op) {
inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
auto error_info = GetErrorWithNodeInfo(n, ex);
inference_errors.push_back(std::make_exception_ptr(ex));
}
});
// Continue with inference for remaining nodes
return;
}
ONNX_CATCH(const ONNX_NAMESPACE::ShapeError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// onnx does not support unsupported/experimental operators
// so it won't consider it as an error
if (!has_unsupported_op) {
auto error_info = GetErrorWithNodeInfo(n, ex);
inference_errors.push_back(std::make_exception_ptr(ex));
}
});
}
@@ -693,23 +708,45 @@
}

void FinalizeShapeInference() {
auto& errors = getErrors();
// Throw shape inference error if any. Error mode right now only supports 0 and 1.
// When set to 0, any node level shape inference errors are not thrown. This is to support backward compatiblity

// with 1.7 and earlier releases. When set to 1 it will throw all exceptions.
// TODO: Add a more granular way for exception handling.
[misspell] onnx/shape_inference/implementation.cc:712: "compatiblity" is a misspelling of "compatibility"
if (!errors.empty() && options.error_mode > 0) {
std::string full_errors = "Inference error(s): ";
for (const std::string& error : inference_errors) {
full_errors += error + "\n";
std::string shape_inference_errors, type_inference_errors, any_inference_errors, other_errors;

[clang-tidy] onnx/shape_inference/implementation.cc:715: warning: variable 'shape_inference_errors' can be declared 'const' [misc-const-correctness]
[clang-tidy] onnx/shape_inference/implementation.cc:715: warning: multiple declarations in a single statement reduces readability [readability-isolate-declaration]

if (!inference_errors.empty() && options.error_mode != IgnoreInferenceError) {
for (const auto& exceptionPtr : inference_errors) {
try {
std::rethrow_exception(exceptionPtr);
} catch (const TypeError& type_error) {
type_inference_errors += type_error.what();
any_inference_errors += type_error.what();
} catch (const ShapeError& shape_error) {
shape_inference_errors += shape_error.what();
any_inference_errors += shape_error.what();
} catch (const InferenceError& inference_error) {
any_inference_errors += inference_error.what();
} catch (const std::exception& other_error) {
other_errors += other_error.what();
}
}
// depend on error_mode and collected errors, fail shape or type inference
// the order of error types being process implies the priority of error types
if (!type_inference_errors.empty()) {
fail_type_inference(type_inference_errors);
} else if (!shape_inference_errors.empty()) {
fail_shape_inference(shape_inference_errors);
} else if (!any_inference_errors.empty()) {
fail_inference(any_inference_errors);
} else if (!other_errors.empty()) {
throw std::runtime_error(other_errors);
}

[clang-tidy] onnx/shape_inference/implementation.cc:743: warning: do not use 'else' after 'throw' [readability-else-after-return]
fail_shape_inference(full_errors);
}
}
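
To summarize the control flow above: the stored exception_ptrs are rethrown, their messages are bucketed by exception type, and the first non-empty bucket wins, so type errors take priority over shape errors, then generic inference errors, then everything else. A Python rendering of that priority order, purely illustrative; the classes here stand in for the bound C++ exceptions.

class InferenceError(RuntimeError): ...
class ShapeError(InferenceError): ...
class TypeError_(InferenceError): ...  # trailing underscore avoids shadowing the builtin

def raise_collected(errors):
    # Mirror of FinalizeShapeInference: concatenate messages per bucket,
    # then throw in priority order (type > shape > any inference > other).
    type_msgs = "".join(str(e) for e in errors if isinstance(e, TypeError_))
    shape_msgs = "".join(str(e) for e in errors if isinstance(e, ShapeError))
    any_msgs = "".join(str(e) for e in errors if isinstance(e, InferenceError))
    other_msgs = "".join(str(e) for e in errors if not isinstance(e, InferenceError))
    if type_msgs:
        raise TypeError_(type_msgs)
    if shape_msgs:
        raise ShapeError(shape_msgs)
    if any_msgs:
        raise InferenceError(any_msgs)
    if other_msgs:
        raise RuntimeError(other_msgs)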

const std::vector<std::string>& getErrors() const {
return inference_errors;
}
//const std::vector<std::string>& getErrors() const {
// return inference_errors;
//}

[Optional Lint] cpplint, onnx/shape_inference/implementation.cc:747: Should have a space between // and comment [whitespace/comments]

private:
InferredTypes inferred_types;
@@ -731,7 +768,7 @@

bool has_unsupported_op = false;

std::vector<std::string> inference_errors;
std::vector<std::exception_ptr> inference_errors;

std::list<TypeProto> initializer_type_list;

@@ -969,7 +1006,7 @@
const std::vector<AttributeProto>& attributes) {
// TODO: if it is desirable for infer_function_output_types to provide check_type, strict_mode, data_prop,
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
FunctionInferenceContext ctx(function_proto, input_types, attributes, options);
auto opset_imports = GetOpsetImportsFromProto(function_proto);
ShapeInferenceImplBase base(
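
The checker_test.py changes below assert the specific exception classes, so it is worth sketching how a caller might handle them, assuming the bindings as drafted. One caveat: py::register_exception without a base argument registers each class directly under Python's Exception, so an existing except InferenceError handler would not catch the new ShapeError or TypeError on the Python side unless the registration wires up that base.

import onnx
from onnx import checker, shape_inference

model = onnx.load("model.onnx")  # hypothetical path

try:
    checker.check_model(model, full_check=True)
except shape_inference.TypeError as e:   # raised via fail_type_inference
    print("type inference failed:", e)
except shape_inference.ShapeError as e:  # raised via fail_shape_inference
    print("shape inference failed:", e)
except shape_inference.InferenceError as e:  # other inference failures
    print("inference failed:", e)
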
onnx/test/checker_test.py (12 changes: 6 additions & 6 deletions)
@@ -163,7 +163,7 @@ def test_check_graph_types(self) -> None:
model = helper.make_model(graph, producer_name="test")

self.assertRaises(
shape_inference.InferenceError, checker.check_model, model, True
shape_inference.ShapeError, checker.check_model, model, True
)

checker.check_graph(graph)
@@ -479,7 +479,7 @@ def test_check_model_unsupported_input_type(self) -> None:
graph = helper.make_graph([node], "test_add_input", [X, Y], [Z])
model = helper.make_model(graph, producer_name="test", opset_imports=[onnx_id])
self.assertRaises(
shape_inference.InferenceError, checker.check_model, model, True
shape_inference.ShapeError, checker.check_model, model, True
)

def test_check_model_inconsistent_type(self) -> None:
@@ -492,7 +492,7 @@ def test_check_model_inconsistent_type(self) -> None:
graph = helper.make_graph([node], "test_add_input", [X, Y], [Z])
model = helper.make_model(graph, producer_name="test", opset_imports=[onnx_id])
self.assertRaises(
shape_inference.InferenceError, checker.check_model, model, True
shape_inference.ShapeError, checker.check_model, model, True
)

def test_check_model_unsupported_output_type(self) -> None:
@@ -505,7 +505,7 @@ def test_check_model_unsupported_output_type(self) -> None:
graph = helper.make_graph([node], "test_add_input", [X, Y], [Z])
model = helper.make_model(graph, producer_name="test", opset_imports=[onnx_id])
self.assertRaises(
shape_inference.InferenceError, checker.check_model, model, True
shape_inference.TypeError, checker.check_model, model, True
)

def test_loop_with_same_initializer_input_below_ir4(self) -> None:
@@ -847,7 +847,7 @@ def test_loop_with_different_initializer_input_below_ir4(self) -> None:
),
)
self.assertRaises(
shape_inference.InferenceError, checker.check_model, model, True
shape_inference.TypeError, checker.check_model, model, True
)

def test_loop_with_same_initializer_input_above_ir4(self) -> None:
@@ -1027,7 +1027,7 @@ def test_loop_with_same_initializer_input_above_ir4(self) -> None:
),
)
self.assertRaises(
shape_inference.InferenceError, checker.check_model, model, True
shape_inference.TypeError, checker.check_model, model, True
)

def test_empty_list_attribute(self):
onnx/test/cpp/function_verify_test.cc (4 changes: 2 additions & 2 deletions)
@@ -484,7 +484,7 @@ foo (x) => (y) {
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
check_model(model);

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}

@@ -550,7 +550,7 @@ foo (x) => (y) {

check_model(model);

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}
