Differentiate shape and type inference errors #5519

Open · wants to merge 14 commits into base: main
4 changes: 2 additions & 2 deletions onnx/checker.cc
@@ -1036,7 +1036,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
check_model(model, ctx);

if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
Contributor:

Would it be possible to scope this enum for clarity? e.g. errors::FailAnyInferenceError or shape_inference::FailAnyInferenceError etc.

@jcwchen any suggestions?

Contributor Author:

That is correct; this PR is far from complete. There are many changes to the API surface, so we want to get them right before merging. I am more concerned about the Python side: naming is one area where we need to be precise. In terms of scope, the errors are already added under the shape_inference module.
It is less of a concern on the C++ side, however.

Contributor:

Curious why it's less of a concern on the C++ side?

Contributor Author:

I should have said it is a concern for both; it is just that the Python API is more visible.

Contributor Author:

@justinchuby, it is better to scope the enum. Do you think InferenceErrorMode is a better name than error?

enum class InferenceErrorMode : uint32_t {
  IgnoreInferenceError,     // Ignore any inference errors
  FailAnyInferenceError,    // Fail on any inference error
  FailShapeInferenceError,  // Fail on any shape inference error, like merging existing shape with inferred etc.
  FailTypeInferenceError    // Fail on any type inference error
};
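For illustration, this is roughly how the scoped names would read from Python if the enum is re-exported under onnx.shape_inference as this PR currently does; the spelling follows the proposal above and is not a settled API.

```python
import onnx
from onnx import shape_inference

# Hypothetical usage: assumes the InferenceErrorMode enum is re-exported under
# onnx.shape_inference (as this PR proposes) and that the names above survive review.
mode = shape_inference.InferenceErrorMode.FailShapeInferenceError

model = onnx.load("model.onnx")  # placeholder path
inferred = shape_inference.infer_shapes(model, error_mode=mode)
```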

ONNX_NAMESPACE::shape_inference::InferShapes(model, ctx.get_schema_registry(), options);
}
}
@@ -1046,7 +1046,7 @@ void check_model(const ModelProto& model, bool full_check, bool skip_opset_compa
ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check);
check_model(model, ctx);
if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
// Do not update the model in place by the check from shape inference
// because checker should not modify the original model
ModelProto copy = model;
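For context, full_check=True is what routes the Python checker through this strict inference path. A rough sketch of what a caller might observe, assuming the C++ InferenceError (or its new subclasses) keeps translating to the exceptions registered in cpp2py_export.cc:

```python
import onnx
from onnx import checker, shape_inference

model = onnx.load("model.onnx")  # placeholder path

try:
    # full_check=True runs shape inference with FailAnyInferenceError after the
    # structural checks (see the ShapeInferenceOptions change above).
    checker.check_model(model, full_check=True)
except shape_inference.InferenceError as err:
    # Whether the aggregated error surfaces as InferenceError or as one of the
    # new ShapeError/TypeError classes depends on how they end up registered.
    print("full check failed:", err)
```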
22 changes: 15 additions & 7 deletions onnx/cpp2py_export.cc
@@ -79,9 +79,9 @@
shape_inference::GraphInferenceContext graphInferenceContext(
valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
// Construct inference context and get results - may throw InferenceError
// TODO: if it is desirable for infer_node_outputs to provide check_type, strict_mode, data_prop,
// TODO: if it is desirable for infer_node_outputs to provide check_type, error_mode, data_prop,

GitHub Actions / Optional Lint (cpplint), onnx/cpp2py_export.cc:82: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo]
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{false, 0, false};
ShapeInferenceOptions options{false, IgnoreInferenceError, false};
shape_inference::InferenceContextImpl ctx(
node, valueTypes.second, inputData.second, inputSparseData.second, options, nullptr, &graphInferenceContext);
schema->GetTypeAndShapeInferenceFunction()(ctx);
@@ -592,31 +592,39 @@
auto shape_inference = onnx_cpp2py_export.def_submodule("shape_inference");
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");
py::register_exception<TypeError>(shape_inference, "TypeError");
py::register_exception<ShapeError>(shape_inference, "ShapeError");

GitHub Actions / Optional Lint (cpplint), onnx/cpp2py_export.cc:597: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] (fixed)
py::enum_<InferenceErrorMode>(shape_inference, "InferenceErrorMode")
.value("IgnoreInferenceError", IgnoreInferenceError)
.value("FailAnyInferenceError", FailAnyInferenceError)
.value("FailShapeInferenceError", FailShapeInferenceError)
.value("FailTypeInferenceError", FailTypeInferenceError);

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
[](const py::bytes& bytes, bool check_type, InferenceErrorMode error_mode, bool data_prop) {
ModelProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(proto, OpSchemaRegistry::Instance(), options);
std::string out;
proto.SerializeToString(&out);
return py::bytes(out);
},
"bytes"_a,
"check_type"_a = false,
"strict_mode"_a = false,
"error_mode"_a = IgnoreInferenceError,
"data_prop"_a = false);

shape_inference.def(
"infer_shapes_path",
[](const std::string& model_path,
const std::string& output_path,
bool check_type,
bool strict_mode,
InferenceErrorMode error_mode,
bool data_prop) -> void {
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(model_path, output_path, OpSchemaRegistry::Instance(), options);
});

39 changes: 28 additions & 11 deletions onnx/defs/shape_inference.h
@@ -18,18 +18,22 @@

using Dim = TensorShapeProto_Dimension;

enum InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors

GitHub Actions / Optional Lint (cpplint), onnx/defs/shape_inference.h:22: At least two spaces is best between code and comments [whitespace/comments]
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};

struct ShapeInferenceOptions {
// Checks the type-equality for input and output
bool check_type;
// 1: Will throw any node level shape infer errors
// 0: Won't throw node-level shape infer errors, but other errors
// like merging existing shape with inferred etc are thrown
int error_mode;
InferenceErrorMode error_mode;
// Enables data propagation for limited operators
// to perform shape computation
bool enable_data_propagation;
ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
: check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
ShapeInferenceOptions(bool check_type_val = false, InferenceErrorMode error_mode_val = IgnoreInferenceError, bool data_prop_val = false)

GitHub Actions / Optional Lint (cpplint), onnx/defs/shape_inference.h:35: Lines should be <= 120 characters long [whitespace/line_length]
GitHub Actions / Optional Lint (cpplint), onnx/defs/shape_inference.h:35: Constructors callable with one argument should be marked explicit. [runtime/explicit]
GitHub Actions / clang-tidy-review, onnx/defs/shape_inference.h:35: constructors that are callable with a single argument must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]; suggested fix: mark the constructor explicit
: check_type(check_type_val), error_mode(error_mode_val), enable_data_propagation(data_prop_val){};

GitHub Actions / Optional Lint (cpplint), onnx/defs/shape_inference.h:36: You don't need a ; after a } [readability/braces]
};

// Maintains a SymbolTable for symbolic shape inference
@@ -52,9 +56,8 @@
virtual ~GraphInferencer() = default;
};

// Exception class used for handling errors in type and shape inference

class InferenceError final : public std::runtime_error {
// Exception classes used for handling errors in type and shape inference
class InferenceError : public std::runtime_error {
public:
using std::runtime_error::runtime_error;

@@ -75,12 +78,26 @@
std::string expanded_message_;
};

class ShapeError final : public InferenceError {
public:
ShapeError(const std::string& message)

GitHub Actions / clang-tidy-review, onnx/defs/shape_inference.h:83: single-argument constructors must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]; suggested fix: explicit ShapeError(const std::string& message)
: InferenceError(ONNX_NAMESPACE::MakeString("[ShapeError] ", message)) {}
};

class TypeError final : public InferenceError {
public:
TypeError(const std::string& message)

GitHub Actions / clang-tidy-review, onnx/defs/shape_inference.h:89: single-argument constructors must be marked explicit to avoid unintentional implicit conversions [google-explicit-constructor]; suggested fix: explicit TypeError(const std::string& message)
: InferenceError(ONNX_NAMESPACE::MakeString("[TypeError] ", message)) {}
};

#define fail_type_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));
ONNX_THROW_EX(ONNX_NAMESPACE::TypeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

#define fail_shape_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
ONNX_THROW_EX(ONNX_NAMESPACE::ShapeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

#define fail_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
struct InferenceContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
virtual size_t getNumInputs() const = 0;
18 changes: 12 additions & 6 deletions onnx/shape_inference.py
@@ -20,7 +20,7 @@
def infer_shapes(
model: ModelProto | bytes,
check_type: bool = False,
strict_mode: bool = False,
Contributor:

We need to keep the strict_mode option for backwards compatibility
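One way to honor that, sketched here only as an illustration (it is not part of this diff), is to keep strict_mode as a deprecated alias that maps onto the new enum inside onnx/shape_inference.py, where the compiled C submodule is already imported:

```python
import warnings

def infer_shapes(model, check_type=False, strict_mode=None, error_mode=None, data_prop=False):
    # Illustrative sketch: translate the legacy boolean into the new enum.
    if strict_mode is not None:
        warnings.warn(
            "strict_mode is deprecated; use error_mode instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if error_mode is None:
            error_mode = (
                C.InferenceErrorMode.FailAnyInferenceError
                if strict_mode
                else C.InferenceErrorMode.IgnoreInferenceError
            )
    if error_mode is None:
        error_mode = C.InferenceErrorMode.IgnoreInferenceError
    # ... then proceed exactly as the function below does, passing error_mode through.
```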

error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

Code scanning / lintrunner, MYPY/name-defined: Name "C.InferenceErrorMode" is not defined
Code scanning / lintrunner, MYPY/attr-defined: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
data_prop: bool = False,
) -> ModelProto:
"""Apply shape inference to the provided ModelProto.
@@ -34,8 +34,11 @@
Arguments:
model (Union[ModelProto, bytes], bool, bool, bool) -> ModelProto
check_type (bool): Checks the type-equality for input and output
strict_mode (bool): Stricter shape inference, it will throw errors if any;
Otherwise, simply stop if any error
error_mode (InferenceErrorMode):
IgnoreInferenceError: ignore any inference error and continue
FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors
FailShapeInferenceError: stop on shape inference error and throw the exception
FailTypeInferenceError: stop on type inference error and throw the exception
data_prop (bool): Enables data propagation for limited operators to perform shape computation

Returns:
@@ -44,7 +47,7 @@
if isinstance(model, (ModelProto, bytes)):
model_str = model if isinstance(model, bytes) else model.SerializeToString()
inferred_model_str = C.infer_shapes(
model_str, check_type, strict_mode, data_prop
model_str, check_type, error_mode, data_prop
)
return onnx.load_from_string(inferred_model_str)
if isinstance(model, str):
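Putting the docstring above into practice, a typical strict call might look like the following, assuming the enum and exception re-exports at the bottom of this file stay as written (the review discussion above suggests names and scoping may still change):

```python
import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # placeholder path

try:
    inferred = shape_inference.infer_shapes(
        model,
        check_type=True,
        error_mode=shape_inference.InferenceErrorMode.FailAnyInferenceError,
        data_prop=False,
    )
except shape_inference.TypeError as err:   # ONNX's TypeError, not the builtin
    print("type inference failed:", err)
except shape_inference.ShapeError as err:
    print("shape inference failed:", err)
except shape_inference.InferenceError as err:
    print("inference failed:", err)
```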
@@ -62,7 +65,7 @@
model_path: str | os.PathLike,
output_path: str | os.PathLike = "",
check_type: bool = False,
strict_mode: bool = False,
error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

Code scanning / lintrunner, MYPY/name-defined: Name "C.InferenceErrorMode" is not defined
Code scanning / lintrunner, MYPY/attr-defined: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
data_prop: bool = False,
) -> None:
"""
@@ -91,7 +94,7 @@

if output_path == "":
output_path = model_path
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
C.infer_shapes_path(model_path, output_path, check_type, error_mode, data_prop)


def infer_node_outputs(
@@ -168,3 +171,6 @@


InferenceError = C.InferenceError
TypeError = C.TypeError

Code scanning / lintrunner, PYLINT/W0622: Redefining built-in 'TypeError' (redefined-builtin)
Code scanning / lintrunner, MYPY/attr-defined: Module has no attribute "TypeError"
Code scanning / lintrunner, RUFF/A001: Variable TypeError is shadowing a Python builtin (https://docs.astral.sh/ruff/rules/builtin-variable-shadowing)
ShapeError = C.ShapeError

Code scanning / lintrunner, MYPY/attr-defined: Module has no attribute "ShapeError"
InferenceErrorMode = C.InferenceErrorMode

Code scanning / lintrunner, MYPY/attr-defined: Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
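Because TypeError (and, to a lesser extent, ShapeError) collides with a builtin name when imported directly, the warnings above suggest keeping the new exceptions module-qualified in user code. A small hedged sketch of that pattern, assuming these re-exports survive review:

```python
from onnx import shape_inference

# Prefer module-qualified access over `from onnx.shape_inference import TypeError`,
# which would shadow the Python builtin in the importing module.
OnnxTypeError = shape_inference.TypeError
OnnxShapeError = shape_inference.ShapeError

def classify(err: Exception) -> str:
    # Hypothetical helper for routing errors, not part of this PR.
    if isinstance(err, OnnxTypeError):
        return "type inference error"
    if isinstance(err, OnnxShapeError):
        return "shape inference error"
    return "other error"
```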
62 changes: 49 additions & 13 deletions onnx/shape_inference/implementation.cc
@@ -458,12 +458,25 @@
return;
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
ONNX_CATCH(const ONNX_NAMESPACE::TypeError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// onnx does not support unsupported/experimental operators
// so it won't consider it as an error
if (!has_unsupported_op && !has_experimental_op) {
inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
auto error_info = GetErrorWithNodeInfo(n, ex);
inference_errors.push_back(std::make_exception_ptr(ex));
}
});
// Continue with inference for remaining nodes
return;
}
ONNX_CATCH(const ONNX_NAMESPACE::ShapeError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// onnx does not support unsupported/experimental operators
// so it won't consider it as an error
if (!has_unsupported_op && !has_experimental_op) {
auto error_info = GetErrorWithNodeInfo(n, ex);
inference_errors.push_back(std::make_exception_ptr(ex));
}
});
// Continue with inference for remaining nodes
Expand Down Expand Up @@ -673,23 +686,46 @@
}

void finalizeShapeInference() {
auto& errors = getErrors();
//auto& errors = getErrors();

GitHub Actions / Optional Lint (cpplint), onnx/shape_inference/implementation.cc:689: Should have a space between // and comment [whitespace/comments]
// Throw shape inference error if any. Error mode right now only supports 0 and 1.
// When set to 0, any node level shape inference errors are not thrown. This is to support backward compatiblity

GitHub Actions / Optional Lint (misspell), onnx/shape_inference/implementation.cc:691: "compatiblity" is a misspelling of "compatibility"
// with 1.7 and earlier releases. When set to 1 it will throw all exceptions.
// TODO: Add a more granular way for exception handling.
if (!errors.empty() && options.error_mode > 0) {
std::string full_errors = "Inference error(s): ";
for (const std::string& error : inference_errors) {
full_errors += error + "\n";
std::string shape_inference_errors, type_inference_errors, any_inference_errors, other_errors;

if (!inference_errors.empty() && options.error_mode != IgnoreInferenceError) {
for (const auto& exceptionPtr : inference_errors) {
try {
std::rethrow_exception(exceptionPtr);
} catch (const TypeError& type_error) {
type_inference_errors += type_error.what();
any_inference_errors += type_error.what();
} catch (const ShapeError& shape_error) {
shape_inference_errors += shape_error.what();
any_inference_errors += shape_error.what();
} catch (const InferenceError& inference_error) {
any_inference_errors += inference_error.what();
} catch (const std::exception& other_error) {
other_errors += other_error.what();
}
}
// depend on error_mode and collected errors, fail shape or type inference
// the order of error types being process implies the priority of error types
if (!type_inference_errors.empty()) {
fail_type_inference(type_inference_errors);
} else if (!shape_inference_errors.empty()) {
fail_shape_inference(shape_inference_errors);
} else if (!any_inference_errors.empty()) {
fail_inference(any_inference_errors);
} else if (!other_errors.empty()) {
throw std::runtime_error(other_errors);
}
fail_shape_inference(full_errors);
}
}

const std::vector<std::string>& getErrors() const {
return inference_errors;
}
//const std::vector<std::string>& getErrors() const {

GitHub Actions / Optional Lint (cpplint), onnx/shape_inference/implementation.cc:726: Should have a space between // and comment [whitespace/comments]
// return inference_errors;
//}

private:
GraphProto& g;
@@ -712,7 +748,7 @@
bool has_experimental_op = false;
bool has_unsupported_op = false;

std::vector<std::string> inference_errors;
std::vector<std::exception_ptr> inference_errors;

std::list<TypeProto> initializer_type_list;

@@ -951,7 +987,7 @@
const std::vector<AttributeProto>& attributes) {
// TODO: if it is desirable for infer_function_output_types to provide check_type, strict_mode, data_prop,
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
FunctionInferenceContext ctx(function_proto, input_types, attributes, options);
auto opset_imports = GetOpsetImportsFromProto(function_proto);
GraphProto g;
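The new finalizeShapeInference flow above (collect each node's exception, then classify and re-raise, with type errors taking priority over shape errors, which take priority over generic inference errors) can be summarized in this rough Python analogue. It is only a sketch of the C++ logic, not code from the PR; the class names are renamed from the C++ TypeError/ShapeError to avoid the Python builtins, and the error_mode enum is reduced to a single flag.

```python
class InferenceError(RuntimeError):
    pass

class ShapeInferenceError(InferenceError):
    pass

class TypeInferenceError(InferenceError):
    pass

def finalize_inference(errors, ignore_errors=False):
    """Mirror of the C++ logic; `errors` holds exceptions collected per node.

    `ignore_errors` stands in for error_mode == IgnoreInferenceError.
    """
    if ignore_errors or not errors:
        return

    type_msgs, shape_msgs, any_msgs, other_msgs = [], [], [], []
    for err in errors:
        if isinstance(err, TypeInferenceError):
            type_msgs.append(str(err))
            any_msgs.append(str(err))
        elif isinstance(err, ShapeInferenceError):
            shape_msgs.append(str(err))
            any_msgs.append(str(err))
        elif isinstance(err, InferenceError):
            any_msgs.append(str(err))
        else:
            other_msgs.append(str(err))

    # Priority mirrors the C++ if/else chain: type > shape > any > other.
    if type_msgs:
        raise TypeInferenceError("\n".join(type_msgs))
    if shape_msgs:
        raise ShapeInferenceError("\n".join(shape_msgs))
    if any_msgs:
        raise InferenceError("\n".join(any_msgs))
    if other_msgs:
        raise RuntimeError("\n".join(other_msgs))
```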
5 changes: 3 additions & 2 deletions onnx/test/automatic_upgrade_test.py
@@ -9,6 +9,7 @@

import onnx
from onnx import TensorProto, ValueInfoProto, helper, shape_inference, version_converter
from onnx.shape_inference import InferenceErrorMode

#####################################################################################
# Every test creates a model containing a single operator from the lowest possible
@@ -106,11 +107,11 @@ def _test_op_upgrade(
opset_imports=[helper.make_opsetid("", from_opset)],
)
onnx.checker.check_model(original)
shape_inference.infer_shapes(original, strict_mode=True)
shape_inference.infer_shapes(original, error_mode=InferenceErrorMode.FailAnyInferenceError)

converted = version_converter.convert_version(original, LATEST_OPSET)
onnx.checker.check_model(converted)
shape_inference.infer_shapes(converted, strict_mode=True)
shape_inference.infer_shapes(converted, error_mode=InferenceErrorMode.FailAnyInferenceError)

def test_Abs(self) -> None:
self._test_op_upgrade("Abs", 1, attrs={"consumed_inputs": [0]})
4 changes: 2 additions & 2 deletions onnx/test/cpp/function_verify_test.cc
@@ -484,7 +484,7 @@ foo (x) => (y) {
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
check_model(model);

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}

@@ -550,7 +550,7 @@ foo (x) => (y) {

check_model(model);

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}

4 changes: 2 additions & 2 deletions onnx/test/cpp/shape_inference_test.cc
@@ -412,7 +412,7 @@ static void doInferencingTest(bool use_scan_opset8) {
std::vector<const TypeProto*> subgraphInputTypes = {&simple_tensor, &simple_tensor};

std::vector<const TensorProto*> subgraphInputData = {};
ShapeInferenceOptions options{false, 0, false};
ShapeInferenceOptions options{false, IgnoreInferenceError, false};
auto output = graphInferencer.doInferencing(subgraphInputTypes, subgraphInputData);

// check the subgraph outputs had their shape inferred when we called
@@ -514,7 +514,7 @@ void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShap
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";

ShapeInferenceOptions options{true, 1, true};
ShapeInferenceOptions options{true, FailAnyInferenceError, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, ONNX_NAMESPACE::OpSchemaRegistry::Instance(), options);

const auto inferredShape = model.graph().output(0).type().tensor_type().shape();
2 changes: 1 addition & 1 deletion onnx/test/function_inference_test.py
@@ -49,7 +49,7 @@ def invoke_inference():
function, input_types, attributes
)

self.assertRaises(onnx.shape_inference.InferenceError, invoke_inference)
self.assertRaises(onnx.shape_inference.ShapeError, invoke_inference)

def test_fi_basic(self):
code = """