-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Differentiate shape and type inference errors #5519
base: main
Are you sure you want to change the base?
Changes from 5 commits
1846156
00cffa9
7a25ab4
4dbf23c
a34720d
73eeb91
099bdcf
8d0644a
1c574ec
5ee4e02
cfa8267
abf0fa7
d6e8efb
5bcf659
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
def infer_shapes( | ||
model: ModelProto | bytes, | ||
check_type: bool = False, | ||
strict_mode: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to keep the strict_mode option for backwards compatibility |
||
error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError, | ||
Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "C.InferenceErrorMode" is not defined
To disable, use # type: ignore[name-defined]
Check failure Code scanning / lintrunner MYPY/attr-defined Error
Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
To disable, use # type: ignore[attr-defined]
|
||
data_prop: bool = False, | ||
) -> ModelProto: | ||
"""Apply shape inference to the provided ModelProto. | ||
|
@@ -36,6 +36,11 @@ | |
check_type: Checks the type-equality for input and output. | ||
strict_mode: Stricter shape inference, it will throw errors if any; | ||
Otherwise, simply stop if any error. | ||
error_mode (InferenceErrorMode): | ||
|
||
IgnoreInferenceError: ignore any inference error and continue | ||
FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors | ||
FailShapeInferenceError: stop on shape inference error and throw the exception | ||
FailTypeInferenceError: stop on type inference error and throw the exception | ||
data_prop: Enables data propagation for limited operators to perform shape computation. | ||
|
||
Returns: | ||
|
@@ -44,7 +49,7 @@ | |
if isinstance(model, (ModelProto, bytes)): | ||
model_str = model if isinstance(model, bytes) else model.SerializeToString() | ||
inferred_model_str = C.infer_shapes( | ||
model_str, check_type, strict_mode, data_prop | ||
model_str, check_type, error_mode, data_prop | ||
) | ||
return onnx.load_from_string(inferred_model_str) | ||
if isinstance(model, str): | ||
|
@@ -62,7 +67,7 @@ | |
model_path: str | os.PathLike, | ||
output_path: str | os.PathLike = "", | ||
check_type: bool = False, | ||
strict_mode: bool = False, | ||
error_mode: C.InferenceErrorMode = C.InferenceErrorMode.IgnoreInferenceError, | ||
Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "C.InferenceErrorMode" is not defined
To disable, use # type: ignore[name-defined]
Check failure Code scanning / lintrunner MYPY/attr-defined Error
Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
To disable, use # type: ignore[attr-defined]
|
||
data_prop: bool = False, | ||
) -> None: | ||
"""Take model path for shape_inference. | ||
|
@@ -93,7 +98,7 @@ | |
|
||
if output_path == "": | ||
output_path = model_path | ||
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop) | ||
C.infer_shapes_path(model_path, output_path, check_type, error_mode, data_prop) | ||
|
||
|
||
def infer_node_outputs( | ||
|
@@ -169,3 +174,6 @@ | |
|
||
|
||
InferenceError = C.InferenceError | ||
TypeError = C.TypeError | ||
Check failure Code scanning / lintrunner MYPY/attr-defined Error
Module has no attribute "TypeError"
To disable, use # type: ignore[attr-defined]
Check warning Code scanning / lintrunner RUFF/A001 Warning
Variable TypeError is shadowing a Python builtin.
See https://docs.astral.sh/ruff/rules/builtin-variable-shadowing |
||
ShapeError = C.ShapeError | ||
Check failure Code scanning / lintrunner MYPY/attr-defined Error
Module has no attribute "ShapeError"
To disable, use # type: ignore[attr-defined]
|
||
InferenceErrorMode = C.InferenceErrorMode | ||
Check failure Code scanning / lintrunner MYPY/attr-defined Error
Module has no attribute "InferenceErrorMode"; maybe "InferenceError"?
To disable, use # type: ignore[attr-defined]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to scope this enum for clarity? e.g.
errors::FailAnyInferenceError
orshape_inference::FailAnyInferenceError
etc.@jcwchen any suggestions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that is correct, this PR is far from complete. There are many changes on API surface so we want to make it right before merge. I am more concerned on the Python side. Naming is one thing that we need to be precise. In terms of scope, it is already under ShapeInference module but adding errors.
It is less a concern on the c++ side however.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious why it's less of a concern on the c++ side?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I should have said it was a concern for both. It is just that python api is more visible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@justinchuby it is better to scope the enum. do you think InferenceErrorMode instead of error is better?
enum class InferenceErrorMode : uint32_t {
IgnoreInferenceError, // Ignore any inference errors
FailAnyInferenceError, // Fail on any inference error
FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
FailTypeInferenceError // Fail on any type inference error
};