
Differentiate shape and type inference errors #5519

Open · wants to merge 14 commits into base: main
4 changes: 2 additions & 2 deletions onnx/checker.cc
@@ -1031,7 +1031,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
check_model(model, ctx);

if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
Contributor:
Would it be possible to scope this enum for clarity? e.g. errors::FailAnyInferenceError or shape_inference::FailAnyInferenceError etc.

@jcwchen any suggestions?

Contributor Author:
That is correct; this PR is far from complete. There are many changes to the API surface, so we want to get them right before merging. I am more concerned about the Python side: naming is one thing we need to be precise about. In terms of scope, it is already under the shape_inference module, but we are adding errors to it.
It is less of a concern on the C++ side, however.

Contributor:
Curious why it's less of a concern on the C++ side?

Contributor Author:
I should have said it is a concern for both; it's just that the Python API is more visible.

Contributor Author:
@justinchuby It is better to scope the enum. Do you think InferenceErrorMode, instead of error, is a better name?

enum class InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};
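
For illustration, this is how the scoped spelling would read from Python once the enum is re-exported through onnx.shape_inference as proposed below (a sketch; the member names follow this PR and are not final):

```python
import onnx.shape_inference as si

# Hypothetical call-site spelling under the scoped-enum proposal.
mode = si.InferenceErrorMode.FailShapeInferenceError
```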

ONNX_NAMESPACE::shape_inference::InferShapes(model, ctx.get_schema_registry(), options);
}
}
@@ -1041,7 +1041,7 @@ void check_model(const ModelProto& model, bool full_check, bool skip_opset_compa
ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check);
check_model(model, ctx);
if (full_check) {
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
// Do not update the model in place by the check from shape inference
// because checker should not modify the original model
ModelProto copy = model;
22 changes: 15 additions & 7 deletions onnx/cpp2py_export.cc
@@ -79,9 +79,9 @@
shape_inference::GraphInferenceContext graphInferenceContext(
valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
// Construct inference context and get results - may throw InferenceError
// TODO: if it is desirable for infer_node_outputs to provide check_type, strict_mode, data_prop,
// TODO: if it is desirable for infer_node_outputs to provide check_type, error_mode, data_prop,

[Optional Lint / cpplint] line 82: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{false, 0, false};
ShapeInferenceOptions options{false, IgnoreInferenceError, false};
shape_inference::InferenceContextImpl ctx(
node, valueTypes.second, inputData.second, inputSparseData.second, options, nullptr, &graphInferenceContext);
schema->GetTypeAndShapeInferenceFunction()(ctx);
@@ -609,31 +609,39 @@
auto shape_inference = onnx_cpp2py_export.def_submodule("shape_inference");
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");
py::register_exception<TypeError>(shape_inference, "TypeError");
py::register_exception<ShapeError>(shape_inference, "ShapeError");

py::enum_<InferenceErrorMode>(shape_inference, "InferenceErrorMode")
.value("IgnoreInferenceError", IgnoreInferenceError)
.value("FailAnyInferenceError", FailAnyInferenceError)
.value("FailShapeInferenceError", FailShapeInferenceError)
.value("FailTypeInferenceError", FailTypeInferenceError);

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
[](const py::bytes& bytes, bool check_type, InferenceErrorMode error_mode, bool data_prop) {
ModelProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(proto, OpSchemaRegistry::Instance(), options);
std::string out;
proto.SerializeToString(&out);
return py::bytes(out);
},
"bytes"_a,
"check_type"_a = false,
"strict_mode"_a = false,
"error_mode"_a = IgnoreInferenceError,
"data_prop"_a = false);

shape_inference.def(
"infer_shapes_path",
[](const std::string& model_path,
const std::string& output_path,
bool check_type,
bool strict_mode,
InferenceErrorMode error_mode,
bool data_prop) -> void {
ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
ShapeInferenceOptions options{check_type, error_mode, data_prop};
shape_inference::InferShapes(model_path, output_path, OpSchemaRegistry::Instance(), options);
});
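
As a quick sanity check of what these registrations expose (a sketch against the compiled extension; note that py::register_exception is called here without an explicit base, so on the Python side TypeError and ShapeError are registered as direct subclasses of Exception, not of InferenceError):

```python
from onnx.onnx_cpp2py_export import shape_inference as C

# The enum members mirror the C++ InferenceErrorMode values.
print(list(C.InferenceErrorMode.__members__))
# ['IgnoreInferenceError', 'FailAnyInferenceError',
#  'FailShapeInferenceError', 'FailTypeInferenceError']

# Registered independently, so a Python `except InferenceError`
# will not catch ShapeError or TypeError.
print(issubclass(C.ShapeError, C.InferenceError))  # False
```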

41 changes: 28 additions & 13 deletions onnx/defs/shape_inference.h
@@ -18,18 +18,25 @@

using Dim = TensorShapeProto_Dimension;

enum InferenceErrorMode : uint32_t {
  IgnoreInferenceError, // Ignore any inference errors
  FailAnyInferenceError, // Fail on any inference error
  FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
  FailTypeInferenceError // Fail on any type inference error
};

[Optional Lint / cpplint] lines 22-25: At least two spaces is best between code and comments [whitespace/comments] [2]

struct ShapeInferenceOptions {
  // Checks the type-equality for input and output
  bool check_type;
  // 1: Will throw any node level shape infer errors
  // 0: Won't throw node-level shape infer errors, but other errors
  // like merging existing shape with inferred etc are thrown
  int error_mode;
  InferenceErrorMode error_mode;

[clang-tidy] line 31: member variable 'error_mode' has public visibility [misc-non-private-member-variables-in-classes]

  // Enables data propagation for limited operators
  // to perform shape computation
  bool enable_data_propagation;
  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
  ShapeInferenceOptions(
      bool check_type_val = false,
      InferenceErrorMode error_mode_val = IgnoreInferenceError,
      bool data_prop_val = false)
      : check_type(check_type_val), error_mode(error_mode_val), enable_data_propagation(data_prop_val){};

[Optional Lint / cpplint] line 39: You don't need a ; after a } [readability/braces] [4]
};

// Maintains a SymbolTable for symbolic shape inference
@@ -55,9 +62,8 @@
virtual ~GraphInferencer() = default;
};

// Exception class used for handling errors in type and shape inference

class InferenceError final : public std::runtime_error {
// Exception classes used for handling errors in type and shape inference
class InferenceError : public std::runtime_error {
public:
using std::runtime_error::runtime_error;

@@ -78,12 +84,21 @@
std::string expanded_message_;
};

#define fail_type_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));
class ShapeError final : public InferenceError {
public:
ShapeError(const std::string& message) : InferenceError(ONNX_NAMESPACE::MakeString("[ShapeError] ", message)) {}
};

class TypeError final : public InferenceError {
public:
TypeError(const std::string& message) : InferenceError(ONNX_NAMESPACE::MakeString("[TypeError] ", message)) {}
};

#define fail_type_inference(...) ONNX_THROW_EX(ONNX_NAMESPACE::TypeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

[clang-tidy] line 97: variadic macro 'fail_type_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]

#define fail_shape_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
#define fail_shape_inference(...) ONNX_THROW_EX(ONNX_NAMESPACE::ShapeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

[clang-tidy] line 99: variadic macro 'fail_shape_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]

#define fail_inference(...) ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

[clang-tidy] line 101: variadic macro 'fail_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]
struct InferenceContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
virtual size_t getNumInputs() const = 0;
16 changes: 12 additions & 4 deletions onnx/shape_inference.py
@@ -20,7 +20,7 @@
def infer_shapes(
model: ModelProto | bytes,
check_type: bool = False,
strict_mode: bool = False,
Contributor: We need to keep the strict_mode option for backwards compatibility.
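
One way to honor that while still moving to the enum would be a deprecation shim along these lines (illustrative only, not part of this PR; `C` is the binding module imported at the top of onnx/shape_inference.py):

```python
import warnings

def _resolve_error_mode(strict_mode, error_mode):
    # Hypothetical helper: translate the deprecated boolean into the new enum.
    if strict_mode is not None:
        warnings.warn(
            "strict_mode is deprecated; use error_mode instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if strict_mode:
            return C.InferenceErrorMode.FailAnyInferenceError
        return C.InferenceErrorMode.IgnoreInferenceError
    return error_mode
```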

error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

[lintrunner / MYPY name-defined] Name "C.InferenceErrorMode" is not defined. To disable, use # type: ignore[name-defined]
[lintrunner / MYPY attr-defined] Module has no attribute "InferenceErrorMode"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
data_prop: bool = False,
) -> ModelProto:
"""Apply shape inference to the provided ModelProto.
Expand All @@ -36,6 +36,11 @@
check_type: Checks the type-equality for input and output.
strict_mode: Stricter shape inference, it will throw errors if any;
Otherwise, simply stop if any error.
error_mode (InferenceErrorMode):
IgnoreInferenceError: ignore any inference error and continue
FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors
FailShapeInferenceError: stop on shape inference error and throw the exception
FailTypeInferenceError: stop on type inference error and throw the exception
data_prop: Enables data propagation for limited operators to perform shape computation.

Returns:
@@ -44,7 +49,7 @@
if isinstance(model, (ModelProto, bytes)):
model_str = model if isinstance(model, bytes) else model.SerializeToString()
inferred_model_str = C.infer_shapes(
model_str, check_type, strict_mode, data_prop
model_str, check_type, error_mode, data_prop
)
return onnx.load_from_string(inferred_model_str)
if isinstance(model, str):
@@ -62,7 +67,7 @@
model_path: str | os.PathLike,
output_path: str | os.PathLike = "",
check_type: bool = False,
strict_mode: bool = False,
error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,

[lintrunner / MYPY name-defined] Name "C.InferenceErrorMode" is not defined. To disable, use # type: ignore[name-defined]
[lintrunner / MYPY attr-defined] Module has no attribute "InferenceErrorMode"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
data_prop: bool = False,
) -> None:
"""Take model path for shape_inference.
@@ -93,7 +98,7 @@

if output_path == "":
output_path = model_path
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
C.infer_shapes_path(model_path, output_path, check_type, error_mode, data_prop)

Check warning on line 101 in onnx/shape_inference.py

View check run for this annotation

Codecov / codecov/patch

onnx/shape_inference.py#L101

Added line #L101 was not covered by tests


def infer_node_outputs(
@@ -169,3 +174,6 @@


InferenceError = C.InferenceError
TypeError = C.TypeError

[lintrunner / MYPY attr-defined] Module has no attribute "TypeError". To disable, use # type: ignore[attr-defined]
[lintrunner / RUFF A001] Variable TypeError is shadowing a Python builtin. See https://docs.astral.sh/ruff/rules/builtin-variable-shadowing
ShapeError = C.ShapeError

[lintrunner / MYPY attr-defined] Module has no attribute "ShapeError". To disable, use # type: ignore[attr-defined]
InferenceErrorMode = C.InferenceErrorMode

[lintrunner / MYPY attr-defined] Module has no attribute "InferenceErrorMode"; maybe "InferenceError"? To disable, use # type: ignore[attr-defined]
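
Putting the new Python surface together, end-to-end usage could look like this sketch (hedged: the parameter and exception names follow this PR and may still change; the model path is hypothetical):

```python
import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # hypothetical input model

try:
    inferred = shape_inference.infer_shapes(
        model,
        check_type=True,
        error_mode=shape_inference.InferenceErrorMode.FailShapeInferenceError,
        data_prop=False,
    )
except shape_inference.ShapeError as e:
    # e.g. an inferred shape could not be merged with an existing one
    print(f"shape inference error: {e}")
except shape_inference.InferenceError as e:
    # Generic aggregate failure; as registered in this PR, ShapeError and
    # TypeError are not Python subclasses of InferenceError, so they must
    # be caught separately.
    print(f"inference error: {e}")
```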
87 changes: 66 additions & 21 deletions onnx/shape_inference/implementation.cc
@@ -497,18 +497,33 @@
}
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
// Note: The following special handling is to accommodate custom-ops. Ideally, custom-ops
// should be registered with a schema in the schema registry, allowing inference to handle
// them. As things stand, this special handling is somewhat fragile and is not fully
// general either. Eg., a custom-op suppresses error-messages for subsequent nodes, but
// this does not work across graphs. If special handling is required, a user-option may
// be a better way to do it. The fragility comes from the fact that the types of the
// returned-values of the custom-op are unknown, and subsequent node-level inference
// may fail because of this.
ONNX_CATCH(const ONNX_NAMESPACE::TypeError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// Note: The following special handling is to accommodate custom-ops. Ideally, custom-ops
// should be registered with a schema in the schema registry, allowing inference to handle
// them. As things stand, this special handling is somewhat fragile and is not fully
// general either. Eg., a custom-op suppresses error-messages for subsequent nodes, but
// this does not work across graphs. If special handling is required, a user-option may
// be a better way to do it. The fragility comes from the fact that the types of the
// returned-values of the custom-op are unknown, and subsequent node-level inference
// may fail because of this.
// ONNX cannot infer unsupported/experimental operators,
// so they are not treated as errors here.
if (!has_unsupported_op) {
inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
auto error_info = GetErrorWithNodeInfo(n, ex);
inference_errors.push_back(std::make_exception_ptr(ex));
}
});
// Continue with inference for remaining nodes
return;
}
ONNX_CATCH(const ONNX_NAMESPACE::ShapeError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// ONNX cannot infer unsupported/experimental operators,
// so they are not treated as errors here.
if (!has_unsupported_op) {
auto error_info = GetErrorWithNodeInfo(n, ex);
inference_errors.push_back(std::make_exception_ptr(ex));
}
});
}
@@ -693,23 +708,53 @@
}

void FinalizeShapeInference() {
  auto& errors = getErrors();
  // Throw shape inference error if any. Error mode right now only supports 0 and 1.
  // When set to 0, any node level shape inference errors are not thrown. This is to support backward compatiblity
  // with 1.7 and earlier releases. When set to 1 it will throw all exceptions.
  // TODO: Add a more granular way for exception handling.
  if (!errors.empty() && options.error_mode > 0) {
    std::string full_errors = "Inference error(s): ";
    for (const std::string& error : inference_errors) {
      full_errors += error + "\n";

[Optional Lint / misspell] line 712: "compatiblity" is a misspelling of "compatibility"
std::string shape_inference_errors, type_inference_errors, any_inference_errors, other_errors;

[clang-tidy] line 715: variable 'shape_inference_errors' of type 'std::string' can be declared 'const' [misc-const-correctness]
[clang-tidy] line 715: multiple declarations in a single statement reduces readability [readability-isolate-declaration]

if (!inference_errors.empty() && options.error_mode != IgnoreInferenceError) {
for (const auto& exceptionPtr : inference_errors) {
ONNX_TRY {
std::rethrow_exception(exceptionPtr);
}
ONNX_CATCH(const ONNX_NAMESPACE::TypeError& type_error) {
ONNX_HANDLE_EXCEPTION([&]() {
type_inference_errors += type_error.what();
any_inference_errors += type_error.what();
});
}
ONNX_CATCH(const ONNX_NAMESPACE::ShapeError& shape_error) {
ONNX_HANDLE_EXCEPTION([&]() {
shape_inference_errors += shape_error.what();
any_inference_errors += shape_error.what();
});
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& inference_error) {
ONNX_HANDLE_EXCEPTION([&]() { any_inference_errors += inference_error.what(); });
}
ONNX_CATCH(const std::exception& other_error) {
ONNX_HANDLE_EXCEPTION([&]() { other_errors += other_error.what(); });
}
}
// Depending on error_mode and the collected errors, fail shape or type inference.
// The order in which the error types are processed implies their priority.
if (!type_inference_errors.empty()) {
fail_type_inference(type_inference_errors);
} else if (!shape_inference_errors.empty()) {
fail_shape_inference(shape_inference_errors);
} else if (!any_inference_errors.empty()) {
fail_inference(any_inference_errors);
} else if (!other_errors.empty()) {
throw std::runtime_error(other_errors);
}

[clang-tidy] line 751: do not use 'else' after 'throw' [readability-else-after-return]
fail_shape_inference(full_errors);
}
}
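
A hedged sketch of what this priority ordering means for a caller of the Python API (assuming the bindings above; the model file is hypothetical):

```python
import onnx
from onnx import shape_inference

# Hypothetical model that accumulates both type and shape inference errors.
model = onnx.load("mixed_errors.onnx")

try:
    shape_inference.infer_shapes(
        model, error_mode=shape_inference.InferenceErrorMode.FailAnyInferenceError
    )
except shape_inference.TypeError as e:
    # Collected type errors are raised first...
    print("type errors reported first:", e)
except shape_inference.ShapeError as e:
    # ...then shape errors, then the generic InferenceError aggregate.
    print("only shape errors were found:", e)
```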

const std::vector<std::string>& getErrors() const {
return inference_errors;
}
// const std::vector<std::string>& getErrors() const {
// return inference_errors;
// }
[CodeQL] Note on lines +755 to +757: This comment appears to contain commented-out code.

private:
InferredTypes inferred_types;
@@ -731,7 +776,7 @@

bool has_unsupported_op = false;

std::vector<std::string> inference_errors;
std::vector<std::exception_ptr> inference_errors;

std::list<TypeProto> initializer_type_list;

@@ -969,7 +1014,7 @@
const std::vector<AttributeProto>& attributes) {
// TODO: if it is desirable for infer_function_output_types to provide check_type, strict_mode, data_prop,
// we can add them to the Python API. For now we just assume the default options.
ShapeInferenceOptions options{true, 1, false};
ShapeInferenceOptions options{true, FailAnyInferenceError, false};
FunctionInferenceContext ctx(function_proto, input_types, attributes, options);
auto opset_imports = GetOpsetImportsFromProto(function_proto);
ShapeInferenceImplBase base(