Differentiate shape and type inference errors #5519
base: main
Conversation
Signed-off-by: Liqun Fu <liqfu@microsoft.com>
lintrunner found more than 10 potential problems in the proposed changes. Check the Files changed tab for more details.
onnx/cpp2py_export.cc
Outdated
@@ -577,31 +577,39 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
   auto shape_inference = onnx_cpp2py_export.def_submodule("shape_inference");
   shape_inference.doc() = "Shape Inference submodule";
   py::register_exception<InferenceError>(shape_inference, "InferenceError");
+  py::register_exception<TypeInferenceError>(shape_inference, "TypeInferenceError");
Just making sure: are we subclassing InferenceError? It would be nice if that is the case, because users can then opt to catch InferenceError for all cases, or catch TypeInferenceError etc. for more specific errors.
Yes, we are subclassing TypeInferenceError and ShapeInferenceError from InferenceError.
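For illustration, here is a minimal sketch of what that hierarchy enables on the Python side, assuming the bindings registered above are re-exported from onnx.shape_inference (the model path is a placeholder, and TypeInferenceError/ShapeInferenceError are the names at this point in the discussion; they are renamed later):

import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # placeholder path

try:
    inferred = shape_inference.infer_shapes(model, strict_mode=True)
except shape_inference.TypeInferenceError:
    # opt into handling only the more specific type-inference failures
    raise
except shape_inference.InferenceError:
    # or catch the base class to handle all inference failures
    pass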
@@ -1036,7 +1036,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
   check_model(model, ctx);

   if (full_check) {
-    ShapeInferenceOptions options{true, 1, false};
+    ShapeInferenceOptions options{true, FailAnyInferenceError, false};
Would it be possible to scope this enum for clarity? E.g. errors::FailAnyInferenceError or shape_inference::FailAnyInferenceError, etc. @jcwchen any suggestions?
That is correct; this PR is far from complete. There are many changes to the API surface, so we want to get it right before merging. I am more concerned about the Python side: naming is one thing we need to be precise about. In terms of scope, it is already under the shape_inference module, but we can add an errors scope.
It is less of a concern on the C++ side, however.
Curious why it's less of a concern on the C++ side?
I should have said it is a concern for both; it is just that the Python API is more visible.
@justinchuby It is better to scope the enum. Do you think InferenceErrorMode instead of errors is better?
enum class InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};
onnx/defs/shape_inference.h
Outdated
@@ -75,11 +78,23 @@ class InferenceError final : public std::runtime_error {
   std::string expanded_message_;
 };

+class ShapeInferenceError final : public InferenceError {
Re. the naming: I wonder if it is better to call these ShapeError and TypeError. (I realize that this naming comes from the original macro names.)
Also, I think these might not be enough: some errors (like missing attributes or wrong attribute values) are not really "shape errors". So, I wonder if we should either rename ShapeError to ValueError, or at least add another exception called ValueError or something more generic.
Good point @gramalingam.
I would like to keep ShapeError for its original use and add AttributeError (more specific than ValueError) for errors raised from attribute checking. In addition to ShapeError, TypeError, and AttributeError, InferShape can also raise ValidationError via schema->CheckInputOutputType; it is currently swallowed as a ShapeError, which is incorrect.
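To make the proposed taxonomy concrete, here is a rough, purely illustrative Python sketch of the hierarchy under discussion (in the PR itself the real exceptions are C++ classes exposed through pybind11, and the TypeError/AttributeError names shadow Python builtins, which is part of the naming concern):

class InferenceError(RuntimeError):
    """Base class for every inference failure."""

class ShapeError(InferenceError):
    """Shape-specific failures, e.g. merging an existing shape with an inferred one."""

class TypeError(InferenceError):  # shadows the Python builtin
    """Type-inference failures."""

class AttributeError(InferenceError):  # proposed for missing or invalid attributes
    """Attribute-checking failures."""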
onnx/defs/shape_inference.h
Outdated
@@ -18,18 +18,22 @@ namespace ONNX_NAMESPACE {

 using Dim = TensorShapeProto_Dimension;

+enum InferenceErrorMode : uint8_t {
nit: I think we could easily use uint32_t. The cost is not significant, and it may help us in the long run.
I also wonder whether we need to expose anything like this externally. Who are the intended users, and what will they use it for? Maybe it is better to focus on the intended use, which is internal: specifically, we want to be able to propagate context information so that we know whether we are processing a top-level node or a node contained inside a conditional (or loop? not sure about that). If it is a node inside a conditional, we want to tolerate some errors, since the node may never execute. So, I wonder if we really need this enumeration, or need to expose it externally.
Another concern here: the checker throws ValidationError. For the example discussed above, it may be necessary to catch/handle checker-errors as well, since a missing attribute might show up as a checker error.
I think it might be helpful to validate ONNX with this PR: #5488 (with converters etc.) to see what failures we run into, if any. We should also create some test cases with if-then-else that we would like to pass the checker, and approach this top-down to make sure we are able to do what we want.
The enum is exposed via infer_shapes and infer_shapes_path. These originally used a binary strict_mode, which can be replaced with IgnoreInferenceError or other enum values. In that sense, we can keep the two interfaces as they were so that the enum does not have to be visible from Python. For the checker, there is an option to run shape inference; in this case the shape inference error mode is set to FailAnyInferenceError, which means a shape-inference-specific exception is thrown. As an example:
onnx/onnx/test/checker_test.py, line 499 in abf0fa7:
self.assertRaises(shape_inference.TypeError, checker.check_model, model, True)
@@ -20,7 +20,7 @@
 def infer_shapes(
     model: ModelProto | bytes,
     check_type: bool = False,
-    strict_mode: bool = False,
+    error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError,
Check failure (Code scanning / lintrunner): MYPY/name-defined Error
Check failure (Code scanning / lintrunner): MYPY/attr-defined Error
@@ -169,3 +174,6 @@


+InferenceError = C.InferenceError
+TypeError = C.TypeError
+ShapeError = C.ShapeError
Check failure (Code scanning / lintrunner): MYPY/attr-defined Error
Check warning (Code scanning / lintrunner): RUFF/A001 builtin-variable-shadowing (see https://docs.astral.sh/ruff/rules/builtin-variable-shadowing)
Codecov Report
Additional details and impacted files:
@@            Coverage Diff             @@
##             main    #5519      +/-   ##
==========================================
+ Coverage   56.45%   56.46%   +0.01%
==========================================
  Files         504      504
  Lines       29865    29875      +10
  Branches     4484     4484
==========================================
+ Hits        16860    16870      +10
  Misses      12188    12188
  Partials      817      817
☔ View full report in Codecov by Sentry.
@@ -19,7 +19,12 @@
     make_tensor_value_info,
 )
 from onnx.numpy_helper import from_array
-from onnx.shape_inference import InferenceError, infer_node_outputs
+from onnx.shape_inference import (
+    InferenceError,
Check warning (Code scanning / lintrunner): RUFF/F401 unused-import (see https://docs.astral.sh/ruff/rules/unused-import)
@@ -17,6 +17,12 @@
 from onnx import ModelProto, NodeProto, TensorProto
 from onnx.backend.base import Device, DeviceType
 from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
+from onnx.shape_inference import (
+    InferenceError,
Check warning (Code scanning / lintrunner): RUFF/F401 unused-import (see https://docs.astral.sh/ruff/rules/unused-import)
from onnx.shape_inference import (
    InferenceError,
    InferenceErrorMode,
    ShapeError,
    TypeError,
Check warning (Code scanning / lintrunner): RUFF/F401 unused-import on ShapeError and TypeError (see https://docs.astral.sh/ruff/rules/unused-import)
from onnx.shape_inference import (
    InferenceError,
    InferenceErrorMode,
    ShapeError,
    infer_node_outputs,
)
Check notice (Code scanning / CodeQL): Unused import
from onnx.shape_inference import (
    InferenceError,
    InferenceErrorMode,
    ShapeError,
    TypeError,
)
Check notice (Code scanning / CodeQL): Unused import. Import of 'ShapeError' is not used. Import of 'TypeError' is not used.
@@ -20,7 +20,7 @@
 def infer_shapes(
     model: ModelProto | bytes,
     check_type: bool = False,
-    strict_mode: bool = False,
We need to keep the strict_mode option for backwards compatibility.
Would it make sense for it to throw errors anyway, and then let the user decide which errors to catch and ignore? That way we can simplify the interface and retain the original options without having to introduce an enum to the API.
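A hedged sketch of how that backwards compatibility could look if the legacy flag is kept and mapped onto the new enum internally (infer_shapes_compat and the mapping below are hypothetical, assuming this PR's error_mode parameter and InferenceErrorMode enum):

from onnx import ModelProto, shape_inference

def infer_shapes_compat(
    model: ModelProto,
    check_type: bool = False,
    strict_mode: bool = False,
) -> ModelProto:
    # Translate the legacy boolean into the new enum so existing callers keep working.
    error_mode = (
        shape_inference.InferenceErrorMode.FailAnyInferenceError
        if strict_mode
        else shape_inference.InferenceErrorMode.IgnoreInferenceError
    )
    return shape_inference.infer_shapes(model, check_type=check_type, error_mode=error_mode)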
Description
Shape inference raises two types of inference errors, but currently there is a single inference error covering both. It is necessary to differentiate them because shape inference errors are minor, whereas type inference errors are fatal. This PR introduces the two error types and uses them in the backend and in Python.
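As a hedged illustration of that intended difference in handling, using the Python aliases and error_mode parameter proposed in this PR (the model path is a placeholder):

import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # placeholder path

try:
    model = shape_inference.infer_shapes(
        model,
        check_type=True,
        error_mode=shape_inference.InferenceErrorMode.FailAnyInferenceError,
    )
except shape_inference.ShapeError as e:
    # shape inference errors are minor: report them and keep the original model
    print(f"non-fatal shape inference issue: {e}")
except shape_inference.TypeError:
    # type inference errors are fatal: let them propagate
    raise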
Motivation and Context
This PR is part of our effort to handle #4986.
Validate with: microsoft/onnxruntime#18948