Differentiate shape and type inference errors #5519

Open · wants to merge 14 commits into base: main
4 changes: 2 additions & 2 deletions onnx/checker.cc
@@ -1036,7 +1036,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
check_model(model, ctx);

if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
Contributor:

Would it be possible to scope this enum for clarity? e.g. errors::FailAnyInferenceError or shape_inference::FailAnyInferenceError etc.

@jcwchen any suggestions?

Contributor Author:

That is correct; this PR is far from complete. There are many changes to the API surface, so we want to get them right before merging. I am more concerned about the Python side: naming is one thing we need to be precise about. In terms of scope, it already sits under the shape inference module, and the errors are being added there.
It is less of a concern on the C++ side, however.

Contributor:

Curious why it's less of a concern on the C++ side?

Contributor Author:

I should have said it is a concern for both. It is just that the Python API is more visible.

Contributor Author:

@justinchuby it is better to scope the enum. Do you think InferenceErrorMode instead of error is a better name?

enum class InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};

ONNX_NAMESPACE::shape_inference::InferShapes(model, ctx.get_schema_registry(), options);
}
}
@@ -1046,7 +1046,7 @@ void check_model(const ModelProto& model, bool full_check, bool skip_opset_compa
ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check);
check_model(model, ctx);
if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
// Do not update the model in place by the check from shape inference
// because checker should not modify the original model
ModelProto copy = model;
22 changes: 15 additions & 7 deletions onnx/cpp2py_export.cc
@@ -79,9 +79,9 @@
shape_inference::GraphInferenceContext graphInferenceContext(
valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
// Construct inference context and get results - may throw InferenceError
// TODO: if it is desirable for infer_node_outputs to provide check_type, strict_mode, data_prop,
// TODO: if it is desirable for infer_node_outputs to provide check_type, error_mode, data_prop,

Check warning on line 82 in onnx/cpp2py_export.cc (GitHub Actions / Optional Lint): [cpplint] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{false, 0, false};
ShapeInferenceOptions options{false, IgnoreInferenceError, false};
shape_inference::InferenceContextImpl ctx(
node, valueTypes.second, inputData.second, inputSparseData.second, options, nullptr, &graphInferenceContext);
schema->GetTypeAndShapeInferenceFunction()(ctx);
@@ -577,31 +577,39 @@
auto shape_inference = onnx_cpp2py_export.def_submodule("shape_inference");
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");
py::register_exception<TypeInferenceError>(shape_inference, "TypeInferenceError");
Contributor:

Just making sure: are we subclassing InferenceError? It would be nice if that were the case, because users could then opt to catch InferenceError for all cases, or catch TypeInferenceError etc. for more specific errors.

Contributor Author:

Yes, we are subclassing: TypeInferenceError and ShapeInferenceError both derive from InferenceError.
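A minimal sketch of what that hierarchy would let a Python caller do, assuming the exceptions and the error_mode parameter are exported from onnx.shape_inference as in this PR (the model path is a placeholder):

import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # placeholder path
try:
    shape_inference.infer_shapes(
        model, error_mode=shape_inference.InferenceErrorMode.FailAnyInferenceError
    )
except shape_inference.TypeInferenceError as e:
    # Specific subclass: only type inference problems land here.
    print("type inference failed:", e)
except shape_inference.InferenceError as e:
    # Base class: catches ShapeInferenceError and any other inference failure.
    print("inference failed:", e)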

py::register_exception<ShapeInferenceError>(shape_inference, "ShapeInferenceError");

Check warning on line 582 in onnx/cpp2py_export.cc (GitHub Actions / Optional Lint): [cpplint] Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4] (Fixed)
py::enum_<InferenceErrorMode>(shape_inference, "InferenceErrorMode")
.value("IgnoreInferenceError", IgnoreInferenceError)
.value("FailAnyInferenceError", FailAnyInferenceError)
.value("FailShapeInferenceError", FailShapeInferenceError)
.value("FailTypeInferenceError", FailTypeInferenceError);

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
[](const py::bytes& bytes, bool check_type, InferenceErrorMode error_mode, bool data_prop) {
ModelProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(proto, OpSchemaRegistry::Instance(), options);
std::string out;
proto.SerializeToString(&out);
return py::bytes(out);
},
"bytes"_a,
"check_type"_a = false,
"strict_mode"_a = false,
"error_mode"_a = IgnoreInferenceError,
"data_prop"_a = false);

shape_inference.def(
"infer_shapes_path",
[](const std::string& model_path,
const std::string& output_path,
bool check_type,
bool strict_mode,
InferenceErrorMode error_mode,
bool data_prop) -> void {
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(model_path, output_path, OpSchemaRegistry::Instance(), options);
});

37 changes: 26 additions & 11 deletions onnx/defs/shape_inference.h
@@ -18,18 +18,22 @@

using Dim = TensorShapeProto_Dimension;

enum InferenceErrorMode : uint8_t {
Contributor:

Nit: I think we could easily use uint32_t. The cost is not significant, and it may help us in the long run.

Contributor:

I also wonder whether we need to expose anything like this externally. Who are the intended users, and what will they use it for? Maybe it is better to focus on the intended use, which is internal: specifically, we want to propagate context information so that we know whether we are processing a top-level node or a node contained inside a conditional (or a loop? not sure about that). If it is a node inside a conditional, we want to tolerate some errors, since the node may never execute. So I wonder whether we really need this enumeration, or need to expose it externally.

Contributor:

Another concern here: the checker throws ValidationError. For the example discussed above, it may be necessary to catch/handle checker errors as well, since a missing attribute might show up as a checker error.

Contributor:

I think it might be helpful to validate ONNX with this PR against #5488 (with converters etc.) to see what failures we run into, if any, and to create some test cases with if-then-else that we would like to pass the checker. Approaching this top-down would make sure we are able to do what we want.

Contributor Author:

The enum is exposed via infer_shapes and infer_shapes_path. They originally used a binary strict_mode, which can be replaced with IgnoreInferenceError or the other enum values. In that sense, we could keep the two interfaces as they were so that the enum does not have to be visible from Python. For the checker, there is an option to run shape inference; in that case the shape inference error mode is set to FailAnyInferenceError, which means a shape-inference-specific exception is thrown. As an example:

self.assertRaises(shape_inference.TypeError, checker.check_model, model, True)
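For context, a hedged before/after sketch of the Python call this refers to (strict_mode is the pre-PR boolean; error_mode is what this PR proposes; the model path is a placeholder):

import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # placeholder path

# Pre-PR API: a single boolean toggled strict behaviour.
# shape_inference.infer_shapes(model, strict_mode=True)

# Proposed API: an enum selects which inference errors are fatal.
inferred = shape_inference.infer_shapes(
    model,
    error_mode=shape_inference.InferenceErrorMode.FailAnyInferenceError,
)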

IgnoreInferenceError, // Ignore any inference errors

Check warning on line 22 in onnx/defs/shape_inference.h (GitHub Actions / Optional Lint): [cpplint] At least two spaces is best between code and comments [whitespace/comments] [2]
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};

struct ShapeInferenceOptions {
// Checks the type-equality for input and output
bool check_type;
// 1: Will throw any node level shape infer errors
// 0: Won't throw node-level shape infer errors, but other errors
// like merging existing shape with inferred etc are thrown
int error_mode;
InferenceErrorMode error_mode;
// Enables data propagation for limited operators
// to perform shape computation
bool enable_data_propagation;
ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
: check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
ShapeInferenceOptions(bool check_type_val = false, InferenceErrorMode error_mode_val = IgnoreInferenceError, bool data_prop_val = false)

Check warning on line 35 in onnx/defs/shape_inference.h (GitHub Actions / Optional Lint): [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
Check warning on line 35 in onnx/defs/shape_inference.h (GitHub Actions / Optional Lint): [cpplint] Constructors callable with one argument should be marked explicit. [runtime/explicit] [5]
Check warning on line 35 in onnx/defs/shape_inference.h (GitHub Actions / clang-tidy-review): warning: constructors that are callable with a single argument must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]; suggested fix: explicit ShapeInferenceOptions(bool check_type_val = false, InferenceErrorMode error_mode_val = IgnoreInferenceError, bool data_prop_val = false)
: check_type(check_type_val), error_mode(error_mode_val), enable_data_propagation(data_prop_val){};

Check warning on line 36 in onnx/defs/shape_inference.h (GitHub Actions / Optional Lint): [cpplint] You don't need a ; after a } [readability/braces] [4]
};

// Maintains a SymbolTable for symbolic shape inference
@@ -52,9 +56,8 @@
virtual ~GraphInferencer() = default;
};

// Exception class used for handling errors in type and shape inference

class InferenceError final : public std::runtime_error {
// Exception classes used for handling errors in type and shape inference
class InferenceError : public std::runtime_error {
public:
using std::runtime_error::runtime_error;

@@ -75,11 +78,23 @@
std::string expanded_message_;
};

class ShapeInferenceError final : public InferenceError {
Contributor:

Re. the naming: I wonder if it is better to call these ShapeError and TypeError. (I realize that this naming comes from the original macro names.)

Also, I think these might not be enough: some errors (like missing attributes or a wrong attribute value) are not really a "shape error". So I wonder if we should either rename ShapeError to ValueError, or at least add another exception called ValueError or something more generic.

Contributor Author:

Good point @gramalingam. I would like to keep ShapeError for its original use and add AttributeError (more specific than ValueError) for errors raised from attribute checking. In addition to ShapeError, TypeError, and AttributeError, InferShapes can also raise ValidationError via schema->CheckInputOutputType. That error is currently swallowed as a shape error, which is incorrect.

public:
ShapeInferenceError(const std::string& message)

Check warning on line 83 in onnx/defs/shape_inference.h (GitHub Actions / clang-tidy-review): warning: single-argument constructors must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]; suggested fix: explicit ShapeInferenceError(const std::string& message)
: InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", message)) {}
};

class TypeInferenceError final : public InferenceError {
public:
TypeInferenceError(const std::string& message)

Check warning on line 89 in onnx/defs/shape_inference.h (GitHub Actions / clang-tidy-review): warning: single-argument constructors must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]; suggested fix: explicit TypeInferenceError(const std::string& message)
: InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", message)) {}
};

#define fail_type_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));
ONNX_THROW_EX(ONNX_NAMESPACE::TypeInferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

#define fail_shape_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
ONNX_THROW_EX(ONNX_NAMESPACE::ShapeInferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

struct InferenceContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
18 changes: 12 additions & 6 deletions onnx/shape_inference.py
@@ -20,7 +20,7 @@
def infer_shapes(
model: ModelProto | bytes,
check_type: bool = False,
strict_mode: bool = False,
Contributor:

We need to keep the strict_mode option for backwards compatibility.

error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

Check failure (Code scanning / lintrunner): MYPY/name-defined: Name "C.InferenceErrorMode" is not defined. To disable, use # type: ignore[name-defined]
Check failure (Code scanning / lintrunner): MYPY/attr-defined: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
data_prop: bool = False,
) -> ModelProto:
"""Apply shape inference to the provided ModelProto.
@@ -34,8 +34,11 @@
Arguments:
model (Union[ModelProto, bytes], bool, bool, bool) -> ModelProto
check_type (bool): Checks the type-equality for input and output
strict_mode (bool): Stricter shape inference, it will throw errors if any;
Otherwise, simply stop if any error
error_mode (InferenceErrorMode):
IgnoreInferenceError: ignore any inference error and continue
FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors
FailShapeInferenceError: stop on shape inference error and throw the exception
FailTypeInferenceError: stop on type inference error and throw the exception
data_prop (bool): Enables data propagation for limited operators to perform shape computation

Returns:
@@ -44,7 +47,7 @@
if isinstance(model, (ModelProto, bytes)):
model_str = model if isinstance(model, bytes) else model.SerializeToString()
inferred_model_str = C.infer_shapes(
model_str, check_type, strict_mode, data_prop
model_str, check_type, error_mode, data_prop
)
return onnx.load_from_string(inferred_model_str)
if isinstance(model, str):
@@ -62,7 +65,7 @@
model_path: str | os.PathLike,
output_path: str | os.PathLike = "",
check_type: bool = False,
strict_mode: bool = False,
error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

Check failure (Code scanning / lintrunner): MYPY/name-defined: Name "C.InferenceErrorMode" is not defined. To disable, use # type: ignore[name-defined]
Check failure (Code scanning / lintrunner): MYPY/attr-defined: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
data_prop: bool = False,
) -> None:
"""
@@ -91,7 +94,7 @@

if output_path == "":
output_path = model_path
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
C.infer_shapes_path(model_path, output_path, check_type, error_mode, data_prop)


def infer_node_outputs(
@@ -168,3 +171,6 @@


InferenceError = C.InferenceError
TypeInferenceError = C.TypeInferenceError

Check failure (Code scanning / lintrunner): MYPY/attr-defined: Module has no attribute "TypeInferenceError"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
ShapeInferenceError = C.ShapeInferenceError

Check failure (Code scanning / lintrunner): MYPY/attr-defined: Module has no attribute "ShapeInferenceError"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
InferenceErrorMode = C.InferenceErrorMode

Check failure (Code scanning / lintrunner): MYPY/attr-defined: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
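A short usage sketch of the path-based API with the new parameter, assuming the signature shown in the onnx/shape_inference.py diff above (file names are placeholders):

from onnx import shape_inference

# Reads the model from disk, runs inference, and writes the result to the output path.
shape_inference.infer_shapes_path(
    "model.onnx",            # placeholder input path
    "model_inferred.onnx",   # placeholder output path
    check_type=True,
    error_mode=shape_inference.InferenceErrorMode.FailShapeInferenceError,
    data_prop=False,
)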
2 changes: 1 addition & 1 deletion onnx/shape_inference/implementation.cc
@@ -951,7 +951,7 @@ std::vector<TypeProto> InferFunctionOutputTypes(
const std::vector<AttributeProto>& attributes) {
// TODO: if it is desirable for infer_function_output_types to provide check_type, strict_mode, data_prop,
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
FunctionInferenceContext ctx(function_proto, input_types, attributes, options);
auto opset_imports = GetOpsetImportsFromProto(function_proto);
GraphProto g;
5 changes: 3 additions & 2 deletions onnx/test/automatic_upgrade_test.py
@@ -9,6 +9,7 @@

import onnx
from onnx import TensorProto, ValueInfoProto, helper, shape_inference, version_converter
from onnx.shape_inference import InferenceError, InferenceErrorMode

Check warning (Code scanning / lintrunner): RUFF/F401: onnx.shape_inference.InferenceError imported but unused. See https://beta.ruff.rs/docs/rules/
Check warning (Code scanning / lintrunner): PYLINT/W0611: Unused InferenceError imported from onnx.shape_inference (unused-import). To disable, use # pylint: disable=unused-import (Fixed)

#####################################################################################
# Every test creates a model containing a single operator from the lowest possible
@@ -106,11 +107,11 @@
opset_imports=[helper.make_opsetid("", from_opset)],
)
onnx.checker.check_model(original)
shape_inference.infer_shapes(original, strict_mode=True)
shape_inference.infer_shapes(original, error_mode=InferenceErrorMode.FailAnyInferenceError)

converted = version_converter.convert_version(original, LATEST_OPSET)
onnx.checker.check_model(converted)
shape_inference.infer_shapes(converted, strict_mode=True)
shape_inference.infer_shapes(converted, error_mode=InferenceErrorMode.FailAnyInferenceError)

def test_Abs(self) -> None:
self._test_op_upgrade("Abs", 1, attrs={"consumed_inputs": [0]})
4 changes: 2 additions & 2 deletions onnx/test/cpp/function_verify_test.cc
@@ -484,7 +484,7 @@ foo (x) => (y) {
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
check_model(model);

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}

@@ -550,7 +550,7 @@ foo (x) => (y) {

check_model(model);

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}

4 changes: 2 additions & 2 deletions onnx/test/cpp/shape_inference_test.cc
@@ -412,7 +412,7 @@ static void doInferencingTest(bool use_scan_opset8) {
std::vector<const TypeProto*> subgraphInputTypes = {&simple_tensor, &simple_tensor};

std::vector<const TensorProto*> subgraphInputData = {};
ShapeInferenceOptions options{false, 0, false};
ShapeInferenceOptions options{false, IgnoreInferenceError, false};
auto output = graphInferencer.doInferencing(subgraphInputTypes, subgraphInputData);

// check the subgraph outputs had their shape inferred when we called
@@ -514,7 +514,7 @@ void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShap
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, ONNX_NAMESPACE::OpSchemaRegistry::Instance(), options);

const auto inferredShape = model.graph().output(0).type().tensor_type().shape();
2 changes: 1 addition & 1 deletion onnx/test/function_inference_test.py
@@ -49,7 +49,7 @@ def invoke_inference():
function, input_types, attributes
)

self.assertRaises(onnx.shape_inference.InferenceError, invoke_inference)
self.assertRaises(onnx.shape_inference.ShapeInferenceError, invoke_inference)

def test_fi_basic(self):
code = """
12 changes: 6 additions & 6 deletions onnx/test/inference_function_test.py
@@ -19,7 +19,8 @@
make_tensor_value_info,
)
from onnx.numpy_helper import from_array
from onnx.shape_inference import InferenceError, infer_node_outputs
from onnx.shape_inference import InferenceError, InferenceErrorMode, infer_node_outputs


ADD_SCHEMA = max(
(s for s in get_all_schemas_with_history() if s.name == "Add" and s.domain == ""),
@@ -236,9 +237,9 @@ def test_inference_with_conflow(self) -> None:
}
"""
model = onnx.parser.parse_model(model_script)
onnx.shape_inference.infer_shapes(model, strict_mode=False)
with self.assertRaises(onnx.shape_inference.InferenceError):
onnx.shape_inference.infer_shapes(model, strict_mode=True)
onnx.shape_inference.infer_shapes(model, error_mode=InferenceErrorMode.IgnoreInferenceError)
with self.assertRaises(onnx.shape_inference.ShapeInferenceError):
onnx.shape_inference.infer_shapes(model, error_mode=InferenceErrorMode.FailAnyInferenceError)

def test_inference_with_attribute(self) -> None:
model_script = """
Expand Down Expand Up @@ -273,8 +274,7 @@ def test_inference_with_attribute(self) -> None:
}
"""
model = onnx.parser.parse_model(model_script)
# onnx.shape_inference.infer_shapes(model, strict_mode=False)
onnx.shape_inference.infer_shapes(model, strict_mode=True)
onnx.shape_inference.infer_shapes(model, error_mode=InferenceErrorMode.FailAnyInferenceError)


if __name__ == "__main__":