-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support set schema inference function in python #5940
base: main
Are you sure you want to change the base?
Changes from all commits
d6d944a
e813a3a
b558aed
095f0fd
a3a00d6
b47c69e
61a5305
f9beedb
3132b4d
3a57147
e3b744b
01f955d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <pybind11/functional.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
|
@@ -15,6 +16,7 @@ | |
#include "onnx/defs/parser.h" | ||
#include "onnx/defs/printer.h" | ||
#include "onnx/defs/schema.h" | ||
#include "onnx/defs/shape_inference.h" | ||
#include "onnx/inliner/inliner.h" | ||
#include "onnx/py_utils.h" | ||
#include "onnx/shape_inference/implementation.h" | ||
|
@@ -114,6 +116,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { | |
#endif // ONNX_ML | ||
); | ||
|
||
// Avoid a segmentation fault if we do not free the Python function in the custom schema | ||
onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); })); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you clarify when this gets invoked? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the segfault is caused by the Python function not being released. About `_cleanup`: https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors |
||
|
||
// Submodule `schema` | ||
auto defs = onnx_cpp2py_export.def_submodule("defs"); | ||
defs.doc() = "Schema submodule"; | ||
|
@@ -394,6 +399,12 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { | |
func_proto.SerializeToString(&func_bytes); | ||
} | ||
return py::bytes(func_bytes); | ||
}) | ||
.def( | ||
"set_type_and_shape_inference_function", | ||
[](OpSchema* op, const std::function<void(InferenceContext*)>& func) { | ||
auto wrapper = [=](InferenceContext& ctx) { func(&ctx); }; | ||
return op->TypeAndShapeInferenceFunction(wrapper); | ||
}); | ||
|
||
defs.def( | ||
|
@@ -625,6 +636,127 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { | |
shape_inference.doc() = "Shape Inference submodule"; | ||
py::register_exception<InferenceError>(shape_inference, "InferenceError"); | ||
|
||
py::class_<InferenceContext> inference_ctx(shape_inference, "InferenceContext", "Inference context"); | ||
|
||
inference_ctx.def("__get_attribute", [](InferenceContext* ctx, std::string name) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_attribute`"); | ||
} | ||
auto attr = ctx->getAttribute(name); | ||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved
Hide resolved
|
||
if (attr == nullptr) { | ||
fail_shape_inference("Internal error: `attr` is nullptr in `__get_attribute`"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not an error situation. In C++, this single method plays a dual role, both to check if an attribute is present, and to get its value. Is it possible to return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or, just a return a C++ pointer (using pybind's conventions ... it must have some standard approach). |
||
} | ||
std::string data; | ||
attr->SerializeToString(&data); | ||
return py::bytes(data); | ||
}); | ||
inference_ctx.def("get_num_inputs", &InferenceContext::getNumInputs); | ||
inference_ctx.def("has_input", &InferenceContext::hasInput); | ||
inference_ctx.def("__get_input_type", [](InferenceContext* ctx, size_t index) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_input_type`"); | ||
} | ||
auto type = ctx->getInputType(index); | ||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved
Hide resolved
|
||
if (type == nullptr) { | ||
fail_shape_inference("Internal error: `type` is nullptr in `__get_input_type`"); | ||
} | ||
std::string data; | ||
type->SerializeToString(&data); | ||
return py::bytes(data); | ||
}); | ||
inference_ctx.def("__get_input_data", [](InferenceContext* ctx, size_t index) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_input_data`"); | ||
} | ||
auto tensor = ctx->getInputData(index); | ||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved
Hide resolved
|
||
if (tensor == nullptr) { | ||
fail_shape_inference("Internal error: `tensor` is nullptr in `__get_input_data`"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, not an error. Signifies no static-constant-value is known for this input. Best to map this to |
||
} | ||
std::string data; | ||
tensor->SerializeToString(&data); | ||
return py::bytes(data); | ||
}); | ||
inference_ctx.def("__get_input_sparse_data", [](InferenceContext* ctx, size_t index) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_input_sparse_data`"); | ||
} | ||
auto stensor = ctx->getInputSparseData(index); | ||
if (stensor == nullptr) { | ||
fail_shape_inference("Internal error: `stensor` is nullptr in `__get_input_sparse_data`"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above. Map to |
||
} | ||
std::string data; | ||
stensor->SerializeToString(&data); | ||
return py::bytes(data); | ||
}); | ||
inference_ctx.def("__get_symbolic_input", [](InferenceContext* ctx, size_t index) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_symbolic_input`"); | ||
} | ||
auto shape = ctx->getSymbolicInput(index); | ||
if (shape == nullptr) { | ||
fail_shape_inference("Internal error: `shape` is nullptr in `__get_symbolic_input`"); | ||
} | ||
std::string data; | ||
shape->SerializeToString(&data); | ||
return py::bytes(data); | ||
}); | ||
inference_ctx.def("__get_graph_attribute_inferencer", &InferenceContext::getGraphAttributeInferencer); | ||
inference_ctx.def("get_num_outputs", &InferenceContext::getNumOutputs); | ||
inference_ctx.def("__get_output_type", [](InferenceContext* ctx, size_t index) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_output_type`"); | ||
} | ||
auto type = ctx->getOutputType(index); | ||
if (type == nullptr) { | ||
fail_shape_inference("Internal error: `type` is nullptr in `__get_output_type`"); | ||
} | ||
std::string data; | ||
type->SerializeToString(&data); | ||
return py::bytes(data); | ||
}); | ||
inference_ctx.def("__set_output_type", [](InferenceContext* ctx, size_t index, py::bytes bytes) { | ||
if (ctx == nullptr) { | ||
fail_shape_inference("Internal error: `ctx` is nullptr in `__set_output_type`"); | ||
} | ||
auto type = ctx->getOutputType(index); | ||
if (type == nullptr) { | ||
fail_shape_inference("Internal error: `type` is nullptr in `__set_output_type`"); | ||
} | ||
ParseProtoFromPyBytes(type, bytes); | ||
}); | ||
|
||
py::class_<GraphInferencer> graph_inferencer(shape_inference, "GraphInferencer", "Graph Inferencer"); | ||
graph_inferencer.def( | ||
"__do_inferencing", | ||
[](GraphInferencer* inferencer, | ||
const std::vector<py::bytes>& input_types, | ||
const std::vector<py::bytes>& input_data) { | ||
std::vector<TypeProto> type_proto; | ||
std::vector<TensorProto> tensor_proto; | ||
std::vector<const TypeProto*> type_inputs; | ||
std::vector<const TensorProto*> tensor_inputs; | ||
for (const auto& bytes : input_types) { | ||
TypeProto proto{}; | ||
ParseProtoFromPyBytes(&proto, bytes); | ||
type_proto.emplace_back(proto); | ||
type_inputs.emplace_back(&type_proto.back()); | ||
} | ||
for (const auto& bytes : input_data) { | ||
TensorProto proto{}; | ||
ParseProtoFromPyBytes(&proto, bytes); | ||
tensor_proto.emplace_back(proto); | ||
tensor_inputs.emplace_back(&tensor_proto.back()); | ||
} | ||
auto ret = inferencer->doInferencing(type_inputs, tensor_inputs); | ||
std::vector<py::bytes> out; | ||
for (const auto& type : ret) { | ||
std::string data; | ||
type->SerializeToString(&data); | ||
out.emplace_back(py::bytes(data)); | ||
} | ||
return out; | ||
}); | ||
|
||
shape_inference.def( | ||
"infer_shapes", | ||
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10172,6 +10172,103 @@ def test_check_type_when_schema_has_empty_io(self): | |
op_schema.name, op_schema.since_version, op_schema.domain | ||
) | ||
|
||
def test_custom_schema_shape_inference(self) -> None: | ||
# CustomOp schema: | ||
# attrs: | ||
# out_len: [L0, L1, ...] | ||
# inputs: | ||
# a[N, La] | ||
# b[N, Lb] | ||
# outputs: | ||
# out0[N, La * Lb, L0] | ||
# out1[N, La * Lb, L1] | ||
# ... | ||
N = 3 | ||
La = 32 | ||
Lb = 64 | ||
out_len = [1, 2] | ||
outs = [f"out{i}" for i in range(len(out_len))] | ||
graph = self._make_graph( | ||
[ | ||
("a", TensorProto.FLOAT, (N, La)), | ||
("b", TensorProto.FLOAT, (N, Lb)), | ||
], | ||
[make_node("CustomOp", ["a", "b"], outs, out_len=out_len)], | ||
[], | ||
) | ||
with self.assertRaises(onnx.checker.ValidationError): | ||
self._assert_inferred( | ||
graph, | ||
[ | ||
make_tensor_value_info( | ||
f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li) | ||
) | ||
for i, Li in enumerate(out_len) | ||
], | ||
) | ||
|
||
schema = OpSchema( | ||
"CustomOp", | ||
"", | ||
1, | ||
inputs=[ | ||
defs.OpSchema.FormalParameter("a", "float"), | ||
defs.OpSchema.FormalParameter("b", "float"), | ||
], | ||
outputs=[ | ||
defs.OpSchema.FormalParameter( | ||
"out", "float", param_option=OpSchema.FormalParameterOption.Variadic | ||
), | ||
], | ||
attributes=[ | ||
defs.OpSchema.Attribute("out_len", defs.OpSchema.AttrType.INTS) | ||
], | ||
) | ||
|
||
def func(ctx: onnx.shape_inference.InferenceContext): | ||
def parse_tensor_input(t: TypeProto): | ||
assert isinstance(t, TypeProto) | ||
return ( | ||
t.tensor_type.elem_type, | ||
[ | ||
d.dim_value if d.HasField("dim_value") else None | ||
for d in t.tensor_type.shape.dim | ||
], | ||
) | ||
|
||
assert ctx.get_num_inputs() == 2 | ||
in0 = ctx.get_input_type(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My concern with this is it goes through serialization to access the type information. It is not really efficient. I would change the API so that it does not return a TypeProto but the type and the shape as regular Python objects. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think I agree with Xavier, but I am a bit confused also. I see the method implementation serializes proto values to string and returns them. We could just return a pointer to the C++ Proto object (wrapped as a Python object). Is that your suggestion, Xavier? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we closely mimic the C++ API design interface, or should we integrate Python's native types for interactions? Utilizing the proto pointer for interactions may require additional code to bind them to Python (if there is another way, please correct me), or we need to include some third-party library. |
||
in1 = ctx.get_input_type(1) | ||
in0_type, in0_shape = parse_tensor_input(in0) | ||
in1_type, in1_shape = parse_tensor_input(in1) | ||
assert in0_type == in1_type == TensorProto.FLOAT | ||
assert len(in0_shape) == len(in1_shape) == 2 | ||
assert in0_shape[0] == in1_shape[0] | ||
N, La = in0_shape | ||
_, Lb = in1_shape | ||
attr = ctx.get_attribute("out_len") | ||
out_len = attr.ints | ||
assert len(out_len) == ctx.get_num_outputs() | ||
for i in range(ctx.get_num_outputs()): | ||
out = ctx.get_output_type(i) | ||
out.tensor_type.elem_type = in0_type | ||
out.tensor_type.shape.dim.add().dim_value = N | ||
out.tensor_type.shape.dim.add().dim_value = La * Lb | ||
out.tensor_type.shape.dim.add().dim_value = out_len[i] | ||
ctx.set_output_type(i, out) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment here, we should avoid serialization with something like |
||
|
||
schema.set_type_and_shape_inference_function(func) | ||
onnx.defs.register_schema(schema) | ||
|
||
self._assert_inferred( | ||
graph, | ||
[ | ||
make_tensor_value_info(f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li)) | ||
for i, Li in enumerate(out_len) | ||
], | ||
) | ||
onnx.defs.deregister_schema(schema.name, schema.since_version, schema.domain) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps a more graceful approach would be to collect the schemas registered from Python and deregister them during cleanup. However, I'm not sure if it's worth the effort. In most cases, invoking cleanup implies that Python is exiting.