Differentiate shape and type inference errors #5519
base: main
onnx/cpp2py_export.cc

@@ -79,9 +79,9 @@
   shape_inference::GraphInferenceContext graphInferenceContext(
       valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
   // Construct inference context and get results - may throw InferenceError
-  // TODO: if it is desirable for infer_node_outputs to provide check_type, strict_mode, data_prop,
+  // TODO: if it is desirable for infer_node_outputs to provide check_type, error_mode, data_prop,
   // we can add them to the Python API. For now we just assume the default options.
-  ShapeInferenceOptions options{false, 0, false};
+  ShapeInferenceOptions options{false, IgnoreInferenceError, false};
   shape_inference::InferenceContextImpl ctx(
       node, valueTypes.second, inputData.second, inputSparseData.second, options, nullptr, &graphInferenceContext);
   schema->GetTypeAndShapeInferenceFunction()(ctx);
@@ -577,31 +577,39 @@
   auto shape_inference = onnx_cpp2py_export.def_submodule("shape_inference");
   shape_inference.doc() = "Shape Inference submodule";
   py::register_exception<InferenceError>(shape_inference, "InferenceError");
+  py::register_exception<TypeInferenceError>(shape_inference, "TypeInferenceError");
+  py::register_exception<ShapeInferenceError>(shape_inference, "ShapeInferenceError");
+
+  py::enum_<InferenceErrorMode>(shape_inference, "InferenceErrorMode")
+      .value("IgnoreInferenceError", IgnoreInferenceError)
+      .value("FailAnyInferenceError", FailAnyInferenceError)
+      .value("FailShapeInferenceError", FailShapeInferenceError)
+      .value("FailTypeInferenceError", FailTypeInferenceError);

   shape_inference.def(
       "infer_shapes",
-      [](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
+      [](const py::bytes& bytes, bool check_type, InferenceErrorMode error_mode, bool data_prop) {
         ModelProto proto{};
         ParseProtoFromPyBytes(&proto, bytes);
-        ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
+        ShapeInferenceOptions options{check_type, error_mode, data_prop};
         shape_inference::InferShapes(proto, OpSchemaRegistry::Instance(), options);
         std::string out;
         proto.SerializeToString(&out);
         return py::bytes(out);
       },
       "bytes"_a,
       "check_type"_a = false,
-      "strict_mode"_a = false,
+      "error_mode"_a = IgnoreInferenceError,
       "data_prop"_a = false);

   shape_inference.def(
       "infer_shapes_path",
       [](const std::string& model_path,
          const std::string& output_path,
          bool check_type,
-          bool strict_mode,
+          InferenceErrorMode error_mode,
          bool data_prop) -> void {
-        ShapeInferenceOptions options{check_type, strict_mode == true ? 1 : 0, data_prop};
+        ShapeInferenceOptions options{check_type, error_mode, data_prop};
         shape_inference::InferShapes(model_path, output_path, OpSchemaRegistry::Instance(), options);
       });

Review thread on the new exception registrations:
- Reviewer: Just making sure: are we subclassing InferenceError? It would be nice if that is the case, because users can then opt to catch the base class.
- Author: Yes, we are subclassing TypeInferenceError and ShapeInferenceError from InferenceError.
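The review thread above confirms that TypeInferenceError and ShapeInferenceError both subclass InferenceError, so callers can catch either the specific error or the common base class. A minimal standalone Python sketch of that pattern (the class names mirror the bindings, but this is not the onnx module itself):

```python
class InferenceError(RuntimeError):
    """Base class for all type/shape inference failures."""

class ShapeInferenceError(InferenceError):
    """Raised when shapes cannot be inferred or merged."""

class TypeInferenceError(InferenceError):
    """Raised when element types cannot be inferred."""

def check_node(kind: str) -> None:
    # Simulate a node that fails one kind of inference.
    if kind == "shape":
        raise ShapeInferenceError("dim mismatch on axis 0")
    if kind == "type":
        raise TypeInferenceError("input element types differ")

# Catching the base class handles both subclasses.
caught = []
for kind in ("shape", "type"):
    try:
        check_node(kind)
    except InferenceError as e:
        caught.append(type(e).__name__)

print(caught)  # ['ShapeInferenceError', 'TypeInferenceError']
```

Code that only cares about one failure kind can instead catch `ShapeInferenceError` or `TypeInferenceError` directly and let the other propagate.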
onnx/defs/shape_inference.h

@@ -18,18 +18,22 @@
 using Dim = TensorShapeProto_Dimension;

+enum InferenceErrorMode : uint8_t {
+  IgnoreInferenceError,    // Ignore any inference errors
+  FailAnyInferenceError,   // Fail on any inference error
+  FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
+  FailTypeInferenceError   // Fail on any type inference error
+};
+
 struct ShapeInferenceOptions {
   // Checks the type-equality for input and output
   bool check_type;
-  // 1: Will throw any node level shape infer errors
-  // 0: Won't throw node-level shape infer errors, but other errors
-  // like merging existing shape with inferred etc are thrown
-  int error_mode;
+  InferenceErrorMode error_mode;
   // Enables data propagation for limited operators
   // to perform shape computation
   bool enable_data_propagation;
-  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
-      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
+  ShapeInferenceOptions(bool check_type_val = false, InferenceErrorMode error_mode_val = IgnoreInferenceError, bool data_prop_val = false)
+      : check_type(check_type_val), error_mode(error_mode_val), enable_data_propagation(data_prop_val){};
 };

 // Maintains a SymbolTable for symbolic shape inference

Review thread on `InferenceErrorMode`:
- Reviewer: nit: I think we could easily use …
- Reviewer: I also wonder whether we need to expose anything like this externally. Who are the intended users, and what will they use it for? Maybe it is better to focus on the intended use, which is internal: specifically, we want to be able to propagate context information so that we know whether we are processing a top-level node, or a node contained inside a conditional (or a loop? not sure about that). If it is a node inside a conditional, we want to tolerate some errors, since the node may never execute. So I wonder if we really need this enumeration, or need to expose it externally.
- Reviewer: Another concern here: the checker throws ValidationError. For the example discussed above, it may be necessary to catch/handle checker errors as well, since a missing attribute might show up as a checker error.
- Reviewer: I think it might be helpful to validate ONNX with this PR: #5488 (with converters etc.) to see what failures we run into, if any, and to create some test cases with if-then-else that we would like to successfully pass the checker, approaching this top-down to make sure we are able to do what we want.
- Author: The enum is exposed via infer_shapes and infer_shapes_path. They originally used a binary strict_mode, which can be replaced with IgnoreInferenceError or the other enum values. In that sense, we can keep the two interfaces as they were so that the enum does not have to be visible from Python. For the checker, there is an option to run shape inference; in that case, the shape inference error mode is set to FailAnyInferenceError, which means a shape-inference-specific exception is thrown. As an example: onnx/onnx/test/checker_test.py, line 499 in abf0fa7.
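The enum above is what the inference loop consults when a node-level error is caught: depending on the mode, the error is either swallowed or rethrown. A hedged Python sketch of that dispatch (a hypothetical `should_rethrow` helper, not the actual C++ implementation):

```python
import enum

class InferenceErrorMode(enum.Enum):
    # Mirrors the C++ enum introduced by this PR.
    IgnoreInferenceError = 0
    FailAnyInferenceError = 1
    FailShapeInferenceError = 2
    FailTypeInferenceError = 3

class InferenceError(RuntimeError): ...
class ShapeInferenceError(InferenceError): ...
class TypeInferenceError(InferenceError): ...

def should_rethrow(err: InferenceError, mode: InferenceErrorMode) -> bool:
    """Decide whether a caught inference error aborts the whole pass."""
    if mode is InferenceErrorMode.IgnoreInferenceError:
        return False
    if mode is InferenceErrorMode.FailAnyInferenceError:
        return True
    if mode is InferenceErrorMode.FailShapeInferenceError:
        return isinstance(err, ShapeInferenceError)
    # FailTypeInferenceError: only type errors are fatal.
    return isinstance(err, TypeInferenceError)

mode = InferenceErrorMode.FailShapeInferenceError
print(should_rethrow(ShapeInferenceError("bad dim"), mode))  # True
print(should_rethrow(TypeInferenceError("bad type"), mode))  # False
```

This also illustrates why the binary strict_mode flag was insufficient: with only a bool there is no way to express "fail on shape errors but tolerate type errors" or vice versa.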
@@ -52,9 +56,8 @@
   virtual ~GraphInferencer() = default;
 };

-// Exception class used for handling errors in type and shape inference
-class InferenceError final : public std::runtime_error {
+// Exception classes used for handling errors in type and shape inference
+class InferenceError : public std::runtime_error {
  public:
   using std::runtime_error::runtime_error;

@@ -75,11 +78,23 @@
   std::string expanded_message_;
 };

+class ShapeInferenceError final : public InferenceError {
+ public:
+  ShapeInferenceError(const std::string& message)
+      : InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", message)) {}
+};
+
+class TypeInferenceError final : public InferenceError {
+ public:
+  TypeInferenceError(const std::string& message)
+      : InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", message)) {}
+};
+
 #define fail_type_inference(...) \
-  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));
+  ONNX_THROW_EX(ONNX_NAMESPACE::TypeInferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

 #define fail_shape_inference(...) \
-  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
+  ONNX_THROW_EX(ONNX_NAMESPACE::ShapeInferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

 struct InferenceContext {
   virtual const AttributeProto* getAttribute(const std::string& name) const = 0;

Review thread on the new subclasses:
- Reviewer: Re. the naming: I wonder if it is better to call these … Also, I think these might not be enough: some errors (like missing attributes or a wrong attribute value) are not really a "shape error". So I wonder if we should either rename …
- Author: Good point @gramalingam …
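A side effect of the change above is that the "[ShapeInferenceError] "/"[TypeInferenceError] " message prefixes move out of the `fail_*` macros and into the exception constructors, so every instance of a subclass carries its tag regardless of where it is thrown. A standalone Python sketch of that prefix-in-constructor design (illustrative, not the onnx module):

```python
class InferenceError(RuntimeError):
    pass

class ShapeInferenceError(InferenceError):
    def __init__(self, message: str) -> None:
        # The prefix is applied once, inside the constructor,
        # instead of at every raise site as the old macros did.
        super().__init__("[ShapeInferenceError] " + message)

class TypeInferenceError(InferenceError):
    def __init__(self, message: str) -> None:
        super().__init__("[TypeInferenceError] " + message)

err = ShapeInferenceError("inferred shape differs from existing shape")
print(str(err))  # [ShapeInferenceError] inferred shape differs from existing shape
```

Centralizing the prefix keeps message tagging consistent even for code paths that construct the exception directly rather than through `fail_shape_inference`/`fail_type_inference`.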
Python shape-inference API:

@@ -20,7 +20,7 @@
 def infer_shapes(
     model: ModelProto | bytes,
     check_type: bool = False,
-    strict_mode: bool = False,
+    error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,
     data_prop: bool = False,
 ) -> ModelProto:
     """Apply shape inference to the provided ModelProto.

Review comment on the removed parameter:
- Reviewer: We need to keep the strict_mode option for backwards compatibility.

[CI: mypy flags the new annotation with `Name "C.InferenceErrorMode" is not defined` and `Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?`]

@@ -34,8 +34,11 @@
     Arguments:
         model (Union[ModelProto, bytes], bool, bool, bool) -> ModelProto
         check_type (bool): Checks the type-equality for input and output
-        strict_mode (bool): Stricter shape inference, it will throw errors if any;
-            Otherwise, simply stop if any error
+        error_mode (InferenceErrorMode):
+            IgnoreInferenceError: ignore any inference error and continue
+            FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors
+            FailShapeInferenceError: stop on shape inference error and throw the exception
+            FailTypeInferenceError: stop on type inference error and throw the exception
         data_prop (bool): Enables data propagation for limited operators to perform shape computation

     Returns:

@@ -44,7 +47,7 @@
     if isinstance(model, (ModelProto, bytes)):
         model_str = model if isinstance(model, bytes) else model.SerializeToString()
         inferred_model_str = C.infer_shapes(
-            model_str, check_type, strict_mode, data_prop
+            model_str, check_type, error_mode, data_prop
         )
         return onnx.load_from_string(inferred_model_str)
     if isinstance(model, str):

@@ -62,7 +65,7 @@
     model_path: str | os.PathLike,
     output_path: str | os.PathLike = "",
     check_type: bool = False,
-    strict_mode: bool = False,
+    error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,
     data_prop: bool = False,
 ) -> None:
     """

@@ -91,7 +94,7 @@
     if output_path == "":
         output_path = model_path
-    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
+    C.infer_shapes_path(model_path, output_path, check_type, error_mode, data_prop)


 def infer_node_outputs(
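A reviewer notes above that strict_mode must be kept for backwards compatibility. One way to satisfy that (a hypothetical shim, not code from this PR) is to accept both keywords and map the legacy bool onto the new enum, warning on the deprecated spelling:

```python
import enum
import warnings
from typing import Optional

class InferenceErrorMode(enum.Enum):
    # Only the two values the legacy bool can express are needed here.
    IgnoreInferenceError = 0
    FailAnyInferenceError = 1

def resolve_error_mode(
    strict_mode: Optional[bool] = None,
    error_mode: Optional[InferenceErrorMode] = None,
) -> InferenceErrorMode:
    """Map the legacy strict_mode flag onto the new enum (hypothetical shim)."""
    if strict_mode is not None:
        warnings.warn(
            "strict_mode is deprecated; use error_mode instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return (
            InferenceErrorMode.FailAnyInferenceError
            if strict_mode
            else InferenceErrorMode.IgnoreInferenceError
        )
    return error_mode if error_mode is not None else InferenceErrorMode.IgnoreInferenceError
```

With a shim like this, existing calls such as `infer_shapes(model, strict_mode=True)` would keep working (resolving to FailAnyInferenceError) while new callers pass `error_mode` directly.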
@@ -168,3 +171,6 @@
 InferenceError = C.InferenceError
+TypeInferenceError = C.TypeInferenceError
+ShapeInferenceError = C.ShapeInferenceError
+InferenceErrorMode = C.InferenceErrorMode

[CI: mypy reports `Module has no attribute` for each of the three new re-exports.]
Version-converter upgrade tests:

@@ -9,6 +9,7 @@
 import onnx
 from onnx import TensorProto, ValueInfoProto, helper, shape_inference, version_converter
+from onnx.shape_inference import InferenceError, InferenceErrorMode

[CI: ruff (F401) and pylint (W0611) flag `InferenceError` as an unused import.]

 #####################################################################################
 # Every test creates a model containing a single operator from the lowest possible

@@ -106,11 +107,11 @@
             opset_imports=[helper.make_opsetid("", from_opset)],
         )
         onnx.checker.check_model(original)
-        shape_inference.infer_shapes(original, strict_mode=True)
+        shape_inference.infer_shapes(original, error_mode=InferenceErrorMode.FailAnyInferenceError)

         converted = version_converter.convert_version(original, LATEST_OPSET)
         onnx.checker.check_model(converted)
-        shape_inference.infer_shapes(converted, strict_mode=True)
+        shape_inference.infer_shapes(converted, error_mode=InferenceErrorMode.FailAnyInferenceError)

     def test_Abs(self) -> None:
         self._test_op_upgrade("Abs", 1, attrs={"consumed_inputs": [0]})
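The docstring earlier in this PR says FailAnyInferenceError throws "the exception combining all inference errors", i.e. per-node failures are accumulated and raised as one exception rather than aborting on the first. A pure-Python sketch of that accumulate-then-raise pattern (illustrative only; the node names and helper are made up):

```python
class InferenceError(RuntimeError):
    pass

def infer_all(nodes):
    """Run per-node inference, collect failures, raise one combined error."""
    errors = []
    for name, infer_fn in nodes:
        try:
            infer_fn()
        except InferenceError as e:
            # Record the failure but keep processing remaining nodes.
            errors.append(f"(node {name}): {e}")
    if errors:
        # Single exception whose message combines every failure.
        raise InferenceError("\n".join(errors))

def bad():
    raise InferenceError("[ShapeInferenceError] rank mismatch")

try:
    infer_all([("Add_0", bad), ("Mul_1", bad)])
except InferenceError as combined:
    msg = str(combined)

print(msg.count("[ShapeInferenceError]"))  # 2
```

Collecting everything before raising lets a user fix all offending nodes in one pass instead of rerunning inference once per error.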
Review thread on scoping the enum:
- Reviewer: Would it be possible to scope this enum for clarity? e.g. `errors::FailAnyInferenceError` or `shape_inference::FailAnyInferenceError`, etc. @jcwchen any suggestions?
- Author: That is correct; this PR is far from complete. There are many changes on the API surface, so we want to get it right before merging. I am more concerned about the Python side: naming is one thing we need to be precise about. In terms of scope, it is already under the shape_inference module, but adding errors. It is less of a concern on the C++ side, however.
- Reviewer: Curious why it's less of a concern on the C++ side?
- Author: I should have said it was a concern for both. It is just that the Python API is more visible.
- Author: @justinchuby it is better to scope the enum. Do you think `InferenceErrorMode` instead of `error` is better?

  enum class InferenceErrorMode : uint32_t {
    IgnoreInferenceError,    // Ignore any inference errors
    FailAnyInferenceError,   // Fail on any inference error
    FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
    FailTypeInferenceError   // Fail on any type inference error
  };