Commit abf0fa7
lint
Signed-off-by: Liqun Fu <liqfu@microsoft.com>
liqunfu committed Dec 25, 2023
1 parent cfa8267 commit abf0fa7
Showing 13 changed files with 120 additions and 82 deletions.
10 changes: 5 additions & 5 deletions onnx/cpp2py_export.cc
@@ -611,12 +611,12 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
py::register_exception<InferenceError>(shape_inference, "InferenceError");
py::register_exception<TypeError>(shape_inference, "TypeError");
py::register_exception<ShapeError>(shape_inference, "ShapeError");

py::enum_<InferenceErrorMode>(shape_inference, "InferenceErrorMode")
.value("IgnoreInferenceError", IgnoreInferenceError)
.value("FailAnyInferenceError", FailAnyInferenceError)
.value("FailShapeInferenceError", FailShapeInferenceError)
.value("FailTypeInferenceError", FailTypeInferenceError);
.value("IgnoreInferenceError", IgnoreInferenceError)
.value("FailAnyInferenceError", FailAnyInferenceError)
.value("FailShapeInferenceError", FailShapeInferenceError)
.value("FailTypeInferenceError", FailTypeInferenceError);

shape_inference.def(
"infer_shapes",
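The hunk above only reflows the enum registration, but what it registers is the Python-visible surface for the new error modes. A minimal sketch of that surface, assuming `onnx.shape_inference` re-exports these bindings (the re-export itself is outside this diff):

import onnx  # noqa: F401
# Sketch of the Python surface registered above; assumes onnx.shape_inference
# re-exports the bindings (that re-export is not part of this diff).
from onnx.shape_inference import (
    InferenceError,      # py::register_exception<InferenceError>
    ShapeError,          # py::register_exception<ShapeError>
    InferenceErrorMode,  # py::enum_<InferenceErrorMode>
)

# pybind11 enums expose their members; the four values mirror the C++ enum.
print(list(InferenceErrorMode.__members__))
# ['IgnoreInferenceError', 'FailAnyInferenceError',
#  'FailShapeInferenceError', 'FailTypeInferenceError']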
26 changes: 12 additions & 14 deletions onnx/defs/shape_inference.h
@@ -20,9 +20,9 @@ using Dim = TensorShapeProto_Dimension;

enum InferenceErrorMode : uint32_t {
  IgnoreInferenceError, // Ignore any inference errors
  FailAnyInferenceError, // Fail on any inference error
  FailShapeInferenceError, // Fail on any shape inference error, like merging existing shape with inferred etc.
  FailTypeInferenceError // Fail on any type inference error
};

Check warning on lines 22-25 in onnx/defs/shape_inference.h (GitHub Actions / Optional Lint): [cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2]

struct ShapeInferenceOptions {
@@ -32,7 +32,10 @@ struct ShapeInferenceOptions {
// Enables data propagation for limited operators
// to perform shape computation
bool enable_data_propagation;
-  ShapeInferenceOptions(bool check_type_val = false, InferenceErrorMode error_mode_val = IgnoreInferenceError, bool data_prop_val = false)
+  ShapeInferenceOptions(
+      bool check_type_val = false,
+      InferenceErrorMode error_mode_val = IgnoreInferenceError,
+      bool data_prop_val = false)
: check_type(check_type_val), error_mode(error_mode_val), enable_data_propagation(data_prop_val){};

Check warning on line 39 in onnx/defs/shape_inference.h (GitHub Actions / Optional Lint): [cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4]
};

@@ -83,24 +86,19 @@ class InferenceError : public std::runtime_error {

class ShapeError final : public InferenceError {
public:
-  ShapeError(const std::string& message)
-      : InferenceError(ONNX_NAMESPACE::MakeString("[ShapeError] ", message)) {}
+  ShapeError(const std::string& message) : InferenceError(ONNX_NAMESPACE::MakeString("[ShapeError] ", message)) {}
};

class TypeError final : public InferenceError {
public:
-  TypeError(const std::string& message)
-      : InferenceError(ONNX_NAMESPACE::MakeString("[TypeError] ", message)) {}
+  TypeError(const std::string& message) : InferenceError(ONNX_NAMESPACE::MakeString("[TypeError] ", message)) {}
};

-#define fail_type_inference(...) \
-  ONNX_THROW_EX(ONNX_NAMESPACE::TypeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
+#define fail_type_inference(...) ONNX_THROW_EX(ONNX_NAMESPACE::TypeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

Check warning on line 97 in onnx/defs/shape_inference.h (GitHub Actions / clang-tidy-review): clang-tidy warning: variadic macro 'fail_type_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]

-#define fail_shape_inference(...) \
-  ONNX_THROW_EX(ONNX_NAMESPACE::ShapeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
+#define fail_shape_inference(...) ONNX_THROW_EX(ONNX_NAMESPACE::ShapeError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

Check warning on line 99 in onnx/defs/shape_inference.h (GitHub Actions / clang-tidy-review): clang-tidy warning: variadic macro 'fail_shape_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]

-#define fail_inference(...) \
-  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
+#define fail_inference(...) ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));

Check warning on line 101 in onnx/defs/shape_inference.h (GitHub Actions / clang-tidy-review): clang-tidy warning: variadic macro 'fail_inference' used; consider using a 'constexpr' variadic template function [cppcoreguidelines-macro-usage]
struct InferenceContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
virtual size_t getNumInputs() const = 0;
2 changes: 1 addition & 1 deletion onnx/shape_inference.py
@@ -36,7 +36,7 @@ def infer_shapes(
check_type: Checks the type-equality for input and output.
strict_mode: Stricter shape inference, it will throw errors if any;
Otherwise, simply stop if any error.
        error_mode (InferenceErrorMode):
IgnoreInferenceError: ignore any inference error and continue
FailAnyInferenceError: stop on any inference error and throw the exception combining all inference errors
FailShapeInferenceError: stop on shape inference error and throw the exception
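The docstring above is the user-facing contract for the new parameter. A short usage sketch; the tiny model is only an illustration, and any valid ModelProto works:

import onnx
import onnx.parser
from onnx.shape_inference import InferenceErrorMode

model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: ["" : 18]>
agraph (float[2,3] X, float[2,3] Y) => (float[2,3] Z) {
    Z = Add(X, Y)
}
""")

# Ignore any inference error and continue (matches the C++ option default).
onnx.shape_inference.infer_shapes(
    model, error_mode=InferenceErrorMode.IgnoreInferenceError
)
# Strictest mode: throw one exception combining all collected errors.
onnx.shape_inference.infer_shapes(
    model, error_mode=InferenceErrorMode.FailAnyInferenceError
)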
18 changes: 7 additions & 11 deletions onnx/shape_inference/implementation.cc
@@ -723,23 +723,19 @@ class ShapeInferenceImplBase {
ONNX_HANDLE_EXCEPTION([&]() {
type_inference_errors += type_error.what();
any_inference_errors += type_error.what();
        });
}
ONNX_CATCH(const ONNX_NAMESPACE::ShapeError& shape_error) {
ONNX_HANDLE_EXCEPTION([&]() {
shape_inference_errors += shape_error.what();
any_inference_errors += shape_error.what();
        });
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& inference_error) {
ONNX_HANDLE_EXCEPTION([&]() {
any_inference_errors += inference_error.what();
});
ONNX_HANDLE_EXCEPTION([&]() { any_inference_errors += inference_error.what(); });
}
ONNX_CATCH(const std::exception& other_error) {
ONNX_HANDLE_EXCEPTION([&]() {
other_errors += other_error.what();
});
ONNX_HANDLE_EXCEPTION([&]() { other_errors += other_error.what(); });
}
}
// depend on error_mode and collected errors, fail shape or type inference
@@ -756,9 +752,9 @@ class ShapeInferenceImplBase {
}
}

-  //const std::vector<std::string>& getErrors() const {
-  //  return inference_errors;
-  //}
+  // const std::vector<std::string>& getErrors() const {
+  //   return inference_errors;
+  // }

Check notice — Code scanning / CodeQL: Commented-out code (Note). This comment appears to contain commented-out code.

private:
InferredTypes inferred_types;
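The walk above accumulates error strings per category and throws only once, after the graph has been processed, with the exception type chosen by error_mode. A sketch of how that surfaces in Python, using a deliberately non-broadcastable Add (3 vs 4) to trigger a shape error:

import onnx
import onnx.parser
from onnx.shape_inference import InferenceError, InferenceErrorMode, ShapeError

# The Add below has deliberately non-broadcastable inputs (3 vs 4).
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: ["" : 18]>
g (float[2,3] X, float[2,4] Y) => (float[2,3] Z) {
    Z = Add(X, Y)
}
""")

try:
    onnx.shape_inference.infer_shapes(
        model, error_mode=InferenceErrorMode.FailAnyInferenceError
    )
except (ShapeError, InferenceError) as err:
    print(err)  # one message combining every error collected for that category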
28 changes: 7 additions & 21 deletions onnx/test/checker_test.py
@@ -162,9 +162,7 @@ def test_check_graph_types(self) -> None:

model = helper.make_model(graph, producer_name="test")

-        self.assertRaises(
-            shape_inference.ShapeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.ShapeError, checker.check_model, model, True)

checker.check_graph(graph)

@@ -181,9 +179,7 @@ def test_check_graph_types(self) -> None:

model = helper.make_model(graph, producer_name="test")

-        self.assertRaises(
-            shape_inference.ShapeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.ShapeError, checker.check_model, model, True)

checker.check_graph(graph)

@@ -478,9 +474,7 @@ def test_check_model_unsupported_input_type(self) -> None:
node = helper.make_node("Add", ["X", "Y"], ["Z"])
graph = helper.make_graph([node], "test_add_input", [X, Y], [Z])
model = helper.make_model(graph, producer_name="test", opset_imports=[onnx_id])
-        self.assertRaises(
-            shape_inference.ShapeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.ShapeError, checker.check_model, model, True)

def test_check_model_inconsistent_type(self) -> None:
N = 10
@@ -491,9 +485,7 @@ def test_check_model_inconsistent_type(self) -> None:
node = helper.make_node("Add", ["X", "Y"], ["Z"])
graph = helper.make_graph([node], "test_add_input", [X, Y], [Z])
model = helper.make_model(graph, producer_name="test", opset_imports=[onnx_id])
-        self.assertRaises(
-            shape_inference.ShapeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.ShapeError, checker.check_model, model, True)

def test_check_model_unsupported_output_type(self) -> None:
N = 10
@@ -504,9 +496,7 @@ def test_check_model_unsupported_output_type(self) -> None:
node = helper.make_node("Add", ["X", "Y"], ["Z"])
graph = helper.make_graph([node], "test_add_input", [X, Y], [Z])
model = helper.make_model(graph, producer_name="test", opset_imports=[onnx_id])
-        self.assertRaises(
-            shape_inference.TypeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.TypeError, checker.check_model, model, True)

def test_loop_with_same_initializer_input_below_ir4(self) -> None:
# This is for testing IR<4: tensors must exist both in initializer and input
@@ -846,9 +836,7 @@ def test_loop_with_different_initializer_input_below_ir4(self) -> None:
],
),
)
-        self.assertRaises(
-            shape_inference.TypeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.TypeError, checker.check_model, model, True)

def test_loop_with_same_initializer_input_above_ir4(self) -> None:
# This is for testing IR>=4:
@@ -1026,9 +1014,7 @@ def test_loop_with_same_initializer_input_above_ir4(self) -> None:
],
),
)
-        self.assertRaises(
-            shape_inference.TypeError, checker.check_model, model, True
-        )
+        self.assertRaises(shape_inference.TypeError, checker.check_model, model, True)

def test_empty_list_attribute(self):
model = onnx.parser.parse_model(
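Every hunk in this file collapses the same assertRaises call onto one line. The behavior those tests pin down, as a standalone hedged sketch: with its second argument (full_check) set to True, checker.check_model runs strict shape inference, so invalid models raise shape_inference.ShapeError or shape_inference.TypeError:

import onnx.parser
from onnx import checker, shape_inference

# Deliberately non-broadcastable Add (3 vs 4), purely as an illustration.
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: ["" : 18]>
g (float[2,3] X, float[2,4] Y) => (float[2,3] Z) {
    Z = Add(X, Y)
}
""")

try:
    checker.check_model(model, True)  # True -> full_check
except (
    shape_inference.InferenceError,
    shape_inference.ShapeError,
    shape_inference.TypeError,
) as err:
    print(err)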
20 changes: 15 additions & 5 deletions onnx/test/inference_function_test.py
@@ -19,8 +19,12 @@
make_tensor_value_info,
)
from onnx.numpy_helper import from_array
-from onnx.shape_inference import InferenceError, InferenceErrorMode, ShapeError, infer_node_outputs
+from onnx.shape_inference import (
+    InferenceError,
+    InferenceErrorMode,
+    ShapeError,
+    infer_node_outputs,
+)

Check warning — Code scanning / lintrunner: RUFF/F401 — onnx.shape_inference.InferenceError imported but unused. See https://docs.astral.sh/ruff/rules/unused-import
Check notice — Code scanning / CodeQL: Unused import (Note) — Import of 'InferenceError' is not used.

ADD_SCHEMA = max(
(s for s in get_all_schemas_with_history() if s.name == "Add" and s.domain == ""),
@@ -237,9 +241,13 @@ def test_inference_with_conflow(self) -> None:
}
"""
model = onnx.parser.parse_model(model_script)
-        onnx.shape_inference.infer_shapes(model, error_mode=InferenceErrorMode.IgnoreInferenceError)
+        onnx.shape_inference.infer_shapes(
+            model, error_mode=InferenceErrorMode.IgnoreInferenceError
+        )
with self.assertRaises(ShapeError):
-            onnx.shape_inference.infer_shapes(model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+            onnx.shape_inference.infer_shapes(
+                model, error_mode=InferenceErrorMode.FailAnyInferenceError
+            )

def test_inference_with_attribute(self) -> None:
model_script = """
Expand Down Expand Up @@ -274,7 +282,9 @@ def test_inference_with_attribute(self) -> None:
}
"""
model = onnx.parser.parse_model(model_script)
-        onnx.shape_inference.infer_shapes(model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        onnx.shape_inference.infer_shapes(
+            model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )


if __name__ == "__main__":
11 changes: 9 additions & 2 deletions onnx/test/model_inference_test.py
@@ -31,7 +31,9 @@ def _check_inference_error(self, model_text: str):
"""Check that the model inference raises an InferenceError."""
model = onnx.parser.parse_model(model_text)
with self.assertRaises(onnx.shape_inference.InferenceError):
-            onnx.shape_inference.infer_shapes(model, True, error_mode=InferenceErrorMode.FailAnyInferenceError)
+            onnx.shape_inference.infer_shapes(
+                model, True, error_mode=InferenceErrorMode.FailAnyInferenceError
+            )

Check warning on line 34 in onnx/test/model_inference_test.py (Codecov / codecov/patch): Added line #L34 was not covered by tests.

def test_unknown_op(self):
"""Test that model inference handles unknown ops.
@@ -164,7 +166,12 @@ def _check_shape(self, model_text: str, *expected: typing.Sequence[int]):
known shapes.
"""
model = onnx.parser.parse_model(model_text)
-        inferred = onnx.shape_inference.infer_shapes(model, True, error_mode=InferenceErrorMode.FailAnyInferenceError, data_prop=True)
+        inferred = onnx.shape_inference.infer_shapes(
+            model,
+            True,
+            error_mode=InferenceErrorMode.FailAnyInferenceError,
+            data_prop=True,
+        )
outputs = inferred.graph.output
for output, expected_shape in zip(outputs, expected):
inferred_type = output.type
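_check_shape above passes data_prop=True, i.e. the enable_data_propagation option from shape_inference.h: for a limited set of operators, computed values (not just shapes) are propagated, which is what lets a Shape→Reshape chain resolve. A hedged sketch, assuming the text parser accepts ? for unknown dimensions:

import onnx
import onnx.parser
from onnx.shape_inference import InferenceErrorMode

# With data_prop=True the value of Shape(X), (4, 6), is propagated into
# Reshape, so Y resolves to a concrete shape instead of staying unknown.
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: ["" : 18]>
g (float[4,6] X) => (float[?,?] Y) {
    s = Shape(X)
    Y = Reshape(X, s)
}
""")

inferred = onnx.shape_inference.infer_shapes(
    model, True, error_mode=InferenceErrorMode.FailAnyInferenceError, data_prop=True
)
print(inferred.graph.output[0].type.tensor_type.shape)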
20 changes: 13 additions & 7 deletions onnx/test/shape_inference_test.py
@@ -13,8 +13,6 @@
from parameterized import parameterized

import onnx.shape_inference
-from onnx.shape_inference import InferenceErrorMode
-
from onnx import (
ONNX_ML,
GraphProto,
@@ -46,6 +44,7 @@
make_tensor_value_info,
)
from onnx.parser import parse_graph
+from onnx.shape_inference import InferenceErrorMode


def get_available_versions(schema: OpSchema) -> set[int]:
@@ -143,9 +142,11 @@ def _inferred(
else:
orig_model = graph_or_model
inferred_model = onnx.shape_inference.infer_shapes(
-            orig_model, error_mode=InferenceErrorMode.FailAnyInferenceError, data_prop=data_prop
+            orig_model,
+            error_mode=InferenceErrorMode.FailAnyInferenceError,
+            data_prop=data_prop,
)

checker.check_model(inferred_model)
return inferred_model

@@ -7368,7 +7369,9 @@ def test_infer_initializer_input_consistency_all_none(self) -> None:
initializer_shape, input_shape
)

-        onnx.shape_inference.infer_shapes(original_model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        onnx.shape_inference.infer_shapes(
+            original_model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )

def test_infer_initializer_input_consistency_single_none(self) -> None:
initializer_shape = (8, 7)
Expand All @@ -7377,7 +7380,9 @@ def test_infer_initializer_input_consistency_single_none(self) -> None:
initializer_shape, input_shape
)

-        onnx.shape_inference.infer_shapes(original_model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        onnx.shape_inference.infer_shapes(
+            original_model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )

def test_infer_initializer_input_consistency_differnt_rank(self) -> None:

Check warning on line 7387 in onnx/test/shape_inference_test.py (GitHub Actions / Optional Lint): [misspell] reported by reviewdog 🐶 "differnt" is a misspelling of "different"
initializer_shape = (8, 7, 9)
@@ -7403,7 +7408,8 @@ def test_infer_initializer_input_consistency_all_none_serialized(self) -> None:
)

onnx.shape_inference.infer_shapes(
-            original_model.SerializeToString(), error_mode=InferenceErrorMode.FailAnyInferenceError
+            original_model.SerializeToString(),
+            error_mode=InferenceErrorMode.FailAnyInferenceError,
)

def test_trilu_upper(self) -> None:
18 changes: 13 additions & 5 deletions onnx/test/symbolic_shape_test.py
@@ -6,9 +6,9 @@
from typing import List, Optional

import onnx.shape_inference
-from onnx.shape_inference import InferenceErrorMode
from onnx import ModelProto, TensorProto, TensorShapeProto, ValueInfoProto, helper
from onnx.helper import make_model, make_tensor_value_info
+from onnx.shape_inference import InferenceErrorMode


class TestSymbolicShape(unittest.TestCase):
@@ -82,7 +82,9 @@ def test_concat_enable_symbolic(self) -> None:
)

onnx_model = make_model(graph_def)
-        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        inferred_model = onnx.shape_inference.infer_shapes(
+            onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )
self._assert_valueinfo_shape(
inferred_model, [make_tensor_value_info("C", TensorProto.FLOAT, (2, -1))]
)
@@ -115,7 +117,9 @@ def test_two_symbolic_concat(self) -> None:
)

onnx_model = make_model(graph_def)
-        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        inferred_model = onnx.shape_inference.infer_shapes(
+            onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )
self._assert_valueinfo_shape(
inferred_model,
[
@@ -155,7 +159,9 @@ def test_duplicate_symbolic_shape(self) -> None:

onnx_model = make_model(graph_def)
original_count = self._count_unique_dim_param_number(onnx_model)
-        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        inferred_model = onnx.shape_inference.infer_shapes(
+            onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )
inferred_count = self._count_unique_dim_param_number(inferred_model)
# to prevent duplicate so the inferred count will be count + 2
# new symbol 'unk__2' and 'unk__3' should be generated
@@ -185,7 +191,9 @@ def test_unknown_shape(self) -> None:
)

onnx_model = make_model(graph_def)
-        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError)
+        inferred_model = onnx.shape_inference.infer_shapes(
+            onnx_model, error_mode=InferenceErrorMode.FailAnyInferenceError
+        )
self._assert_valueinfo_shape(
inferred_model, [make_tensor_value_info("C", TensorProto.FLOAT, (3, -1))]
)
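The duplicate-symbol test above checks that inference invents fresh dim_param names (unk__2, unk__3) instead of reusing ones already in the model. A small helper sketch for inspecting those symbols in any inferred model (the helper is illustrative, not part of the ONNX API):

from onnx import ModelProto


def collect_dim_params(model: ModelProto) -> set:
    """Collect every symbolic dimension name in inputs, outputs, and value_info."""
    symbols = set()
    graph = model.graph
    for vi in list(graph.input) + list(graph.output) + list(graph.value_info):
        if not vi.type.HasField("tensor_type"):
            continue  # skip sequence/map/optional values
        for dim in vi.type.tensor_type.shape.dim:
            if dim.HasField("dim_param"):  # symbolic, e.g. "unk__2"
                symbols.add(dim.dim_param)
    return symbols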
