Support set schema inference function in python #5940

Open · wants to merge 12 commits into main
132 changes: 132 additions & 0 deletions onnx/cpp2py_export.cc
@@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

@@ -15,6 +16,7 @@
#include "onnx/defs/parser.h"
#include "onnx/defs/printer.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
#include "onnx/inliner/inliner.h"
#include "onnx/py_utils.h"
#include "onnx/shape_inference/implementation.h"
@@ -114,6 +116,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
#endif // ONNX_ML
);

// Avoid a segmentation fault if we do not free the Python functions held by custom schemas
onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); }));
Contributor Author:

Perhaps a more graceful approach would be to collect the schemas registered from Python and deregister them during cleanup. However, I'm not sure if it's worth the effort. In most cases, invoking cleanup implies that Python is exiting.

Contributor:

Can you clarify when this gets invoked?

Contributor Author:

I think the segfault is caused by the following: the Python object (the inference function in the custom schema) must be destroyed before the Python interpreter is destroyed. The static container inside the schema factory is destroyed after the main function exits but before the interpreter is torn down, so we need to release the Python object manually.

About '_cleanup' : https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors
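For reference, a minimal sketch of the lifetime problem, assuming the API added in this PR (`set_type_and_shape_inference_function` plus the existing `onnx.defs.register_schema`): the Python callable is captured by a `std::function` stored in the static `OpSchemaRegistry`, so without the `_cleanup` capsule it would only be released during C++ static teardown, after the interpreter has already been finalized.

```python
import onnx
from onnx.defs import OpSchema

def infer(ctx):
    # Python callable that will be held by the C++ schema registry.
    pass

schema = OpSchema("CustomOp", "", 1)
schema.set_type_and_shape_inference_function(infer)  # C++ std::function now references `infer`
onnx.defs.register_schema(schema)                    # stored in the static OpSchemaRegistry

# On interpreter exit, the `_cleanup` capsule calls OpSchemaDeregisterAll(),
# dropping the reference to `infer` while Python is still alive; without it the
# reference would be released during static destruction, after finalization.
```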


// Submodule `schema`
auto defs = onnx_cpp2py_export.def_submodule("defs");
defs.doc() = "Schema submodule";
@@ -394,6 +399,12 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
func_proto.SerializeToString(&func_bytes);
}
return py::bytes(func_bytes);
})
.def(
"set_type_and_shape_inference_function",
[](OpSchema* op, const std::function<void(InferenceContext*)>& func) {
auto wrapper = [=](InferenceContext& ctx) { func(&ctx); };
return op->TypeAndShapeInferenceFunction(wrapper);
});

defs.def(
@@ -625,6 +636,127 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");

py::class_<InferenceContext> inference_ctx(shape_inference, "InferenceContext", "Inference context");

inference_ctx.def("__get_attribute", [](InferenceContext* ctx, std::string name) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_attribute`");
}
auto attr = ctx->getAttribute(name);
if (attr == nullptr) {
fail_shape_inference("Internal error: `attr` is nullptr in `__get_attribute`");
Contributor:

This is not an error situation. In C++, this single method plays a dual role, both to check if an attribute is present, and to get its value. Is it possible to return None in this situation?

Contributor:

Or, just return a C++ pointer (using pybind's conventions ... it must have some standard approach).

}
std::string data;
attr->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("get_num_inputs", &InferenceContext::getNumInputs);
inference_ctx.def("has_input", &InferenceContext::hasInput);
inference_ctx.def("__get_input_type", [](InferenceContext* ctx, size_t index) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_input_type`");
}
auto type = ctx->getInputType(index);
if (type == nullptr) {
fail_shape_inference("Internal error: `type` is nullptr in `__get_input_type`");
}
std::string data;
type->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_input_data", [](InferenceContext* ctx, size_t index) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_input_data`");
}
auto tensor = ctx->getInputData(index);
if (tensor == nullptr) {
fail_shape_inference("Internal error: `tensor` is nullptr in `__get_input_data`");
Contributor:

Again, not an error. It signifies that no static constant value is known for this input. Best to map this to None or something similar.

}
std::string data;
tensor->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_input_sparse_data", [](InferenceContext* ctx, size_t index) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_input_sparse_data`");
}
auto stensor = ctx->getInputSparseData(index);
if (stensor == nullptr) {
fail_shape_inference("Internal error: `stensor` is nullptr in `__get_input_sparse_data`");
Contributor:

As above. Map to None?

}
std::string data;
stensor->SerializeToString(&data);
return py::bytes(data);
});
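To illustrate the reviewers' suggestion above, a hypothetical sketch of the None-returning behavior (not what this PR currently implements; the bindings here raise an inference error when the attribute or constant value is absent); `ctx` is assumed to be the InferenceContext passed to a registered inference function:

```python
# Hypothetical: absent attributes / input data map to None instead of raising.
attr = ctx.get_attribute("out_len")  # AttributeProto, or None if the attribute is not present
data = ctx.get_input_data(0)         # TensorProto, or None if no static constant value is known

out_len = list(attr.ints) if attr is not None else [1]  # fall back to a default
if data is not None:
    pass  # value-dependent shape inference could use the constant here
```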
inference_ctx.def("__get_symbolic_input", [](InferenceContext* ctx, size_t index) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_symbolic_input`");
}
auto shape = ctx->getSymbolicInput(index);
if (shape == nullptr) {
fail_shape_inference("Internal error: `shape` is nullptr in `__get_symbolic_input`");
}
std::string data;
shape->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_graph_attribute_inferencer", &InferenceContext::getGraphAttributeInferencer);
inference_ctx.def("get_num_outputs", &InferenceContext::getNumOutputs);
inference_ctx.def("__get_output_type", [](InferenceContext* ctx, size_t index) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__get_output_type`");
}
auto type = ctx->getOutputType(index);
if (type == nullptr) {
fail_shape_inference("Internal error: `type` is nullptr in `__get_output_type`");
}
std::string data;
type->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__set_output_type", [](InferenceContext* ctx, size_t index, py::bytes bytes) {
if (ctx == nullptr) {
fail_shape_inference("Internal error: `ctx` is nullptr in `__set_output_type`");
}
auto type = ctx->getOutputType(index);
if (type == nullptr) {
fail_shape_inference("Internal error: `type` is nullptr in `__set_output_type`");
}
ParseProtoFromPyBytes(type, bytes);
});

py::class_<GraphInferencer> graph_inferencer(shape_inference, "GraphInferencer", "Graph Inferencer");
graph_inferencer.def(
"__do_inferencing",
[](GraphInferencer* inferencer,
const std::vector<py::bytes>& input_types,
const std::vector<py::bytes>& input_data) {
std::vector<TypeProto> type_proto;
std::vector<TensorProto> tensor_proto;
std::vector<const TypeProto*> type_inputs;
std::vector<const TensorProto*> tensor_inputs;
for (const auto& bytes : input_types) {
TypeProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
type_proto.emplace_back(proto);
type_inputs.emplace_back(&type_proto.back());
}
for (const auto& bytes : input_data) {
TensorProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
tensor_proto.emplace_back(proto);
tensor_inputs.emplace_back(&tensor_proto.back());
}
auto ret = inferencer->doInferencing(type_inputs, tensor_inputs);
std::vector<py::bytes> out;
for (const auto& type : ret) {
std::string data;
type->SerializeToString(&data);
out.emplace_back(py::bytes(data));
}
return out;
});

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
1 change: 1 addition & 0 deletions onnx/onnx_cpp2py_export/defs.pyi
@@ -68,6 +68,7 @@
) -> dict[str, bytes]: ...
@property
def function_body(self) -> FunctionProto: ...
def set_type_and_shape_inference_function(self, func) -> None: ...

class TypeConstraintParam:
def __init__(
32 changes: 32 additions & 0 deletions onnx/onnx_cpp2py_export/shape_inference.pyi
@@ -1,6 +1,38 @@
from typing import List

from onnx import AttributeProto, TypeProto, TensorProto, SparseTensorProto, TensorShapeProto

class InferenceError(Exception): ...

class GraphInferencer:
# Impl in cpp (onnx/cpp2py_export.cc)
def __do_inferencing(self, input_types: List[bytes], input_data: List[bytes]) -> List[bytes]: ...
# Impl in py (onnx/shape_inference.py)
def do_inferencing(self, input_types: List[TypeProto], input_data: List[TensorProto]) -> List[TypeProto]: ...

class InferenceContext:
# Impl in cpp (onnx/cpp2py_export.cc)
def get_num_inputs(self) -> int: ...
def has_input(self, idx: int) -> bool: ...
def get_num_outputs(self) -> int: ...
def __get_attribute(self, name: str) -> bytes: ...
def __get_input_type(self, idx: int) -> bytes: ...
def __get_input_data(self, idx: int) -> bytes: ...
def __get_input_sparse_data(self, idx: int) -> bytes: ...
def __get_symbolic_input(self, idx: int) -> bytes: ...
def __get_graph_attribute_inferencer(self) -> GraphInferencer: ...
def __get_output_type(self, idx: int) -> bytes: ...
def __set_output_type(self, idx: int, output: bytes) -> None: ...
# Impl in py (onnx/shape_inference.py)
def get_attribute(self, name: str) -> AttributeProto: ...
def get_input_type(self, idx: int) -> TypeProto: ...
def get_input_data(self, idx: int) -> TensorProto: ...
def get_input_sparse_data(self, idx: int) -> SparseTensorProto: ...
def get_symbolic_input(self, idx: int) -> TensorShapeProto: ...
def get_graph_attribute_inferencer(self) -> GraphInferencer: ...
def get_output_type(self, idx: int) -> TypeProto: ...
def set_output_type(self, idx: int, output: TypeProto) -> None: ...

def infer_shapes(
b: bytes, check_type: bool, strict_mode: bool, data_prop: bool
) -> bytes: ...
53 changes: 52 additions & 1 deletion onnx/shape_inference.py
@@ -14,7 +14,58 @@

import onnx
import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto
from onnx import (
AttributeProto,
FunctionProto,
ModelProto,
SparseTensorProto,
TensorProto,
TensorShapeProto,
TypeProto,
)

GraphInferencer = C.GraphInferencer


def _do_inferencing(
self, input_types: list[TypeProto], input_data: list[TensorProto]
) -> list[TypeProto]:
input_types_bytes = [proto.SerializeToString() for proto in input_types]
input_data_bytes = [proto.SerializeToString() for proto in input_data]
ret = self.__impl.__do_inferencing(input_types_bytes, input_data_bytes)

return [TypeProto.FromString(data) for data in ret]


GraphInferencer.do_inferencing = _do_inferencing # type: ignore


def _parse_to_proto(attr, proto_type):
def impl(self, *args, **kwargs):
data = getattr(self, attr)(*args, **kwargs)
return proto_type.FromString(data)

return impl


InferenceContext = C.InferenceContext
InferenceContext.get_attribute = _parse_to_proto("__get_attribute", AttributeProto) # type: ignore
InferenceContext.get_input_type = _parse_to_proto("__get_input_type", TypeProto) # type: ignore
InferenceContext.get_input_data = _parse_to_proto("__get_input_data", TensorProto) # type: ignore
InferenceContext.get_input_sparse_data = _parse_to_proto( # type: ignore
"__get_input_sparse_data", SparseTensorProto
)
InferenceContext.get_symbolic_input = _parse_to_proto( # type: ignore
"__get_symbolic_input", TensorShapeProto
)
InferenceContext.get_output_type = _parse_to_proto("__get_output_type", TypeProto) # type: ignore


def _op_set_output_type(self, idx: int, output: TypeProto):
data = output.SerializeToString()
self.__set_output_type(idx, data)


InferenceContext.set_output_type = _op_set_output_type # type: ignore
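To make the bridging pattern explicit, a small sketch (assuming `ctx` is an InferenceContext handed to a registered inference function): each public accessor calls the byte-returning C++ binding and parses the result into the corresponding proto class.

```python
# Roughly what ctx.get_attribute("out_len") does after the monkeypatching above:
raw = getattr(ctx, "__get_attribute")("out_len")  # serialized bytes from the C++ binding
attr = AttributeProto.FromString(raw)             # parsed back into a proto on the Python side
assert attr.ints == ctx.get_attribute("out_len").ints
```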


def infer_shapes(
97 changes: 97 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -10172,6 +10172,103 @@ def test_check_type_when_schema_has_empty_io(self):
op_schema.name, op_schema.since_version, op_schema.domain
)

def test_custom_schema_shape_inference(self) -> None:
# CustomOp schema:
# attrs:
# out_len: [L0, L1, ...]
# inputs:
# a[N, La]
# b[N, Lb]
# outputs:
# out0[N, La * Lb, L0]
# out1[N, La * Lb, L1]
# ...
N = 3
La = 32
Lb = 64
out_len = [1, 2]
outs = [f"out{i}" for i in range(len(out_len))]
graph = self._make_graph(
[
("a", TensorProto.FLOAT, (N, La)),
("b", TensorProto.FLOAT, (N, Lb)),
],
[make_node("CustomOp", ["a", "b"], outs, out_len=out_len)],
[],
)
with self.assertRaises(onnx.checker.ValidationError):
self._assert_inferred(
graph,
[
make_tensor_value_info(
f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li)
)
for i, Li in enumerate(out_len)
],
)

schema = OpSchema(
"CustomOp",
"",
1,
inputs=[
defs.OpSchema.FormalParameter("a", "float"),
defs.OpSchema.FormalParameter("b", "float"),
],
outputs=[
defs.OpSchema.FormalParameter(
"out", "float", param_option=OpSchema.FormalParameterOption.Variadic
),
],
attributes=[
defs.OpSchema.Attribute("out_len", defs.OpSchema.AttrType.INTS)
],
)

def func(ctx: onnx.shape_inference.InferenceContext):
def parse_tensor_input(t: TypeProto):
assert isinstance(t, TypeProto)
return (
t.tensor_type.elem_type,
[
d.dim_value if d.HasField("dim_value") else None
for d in t.tensor_type.shape.dim
],
)

assert ctx.get_num_inputs() == 2
in0 = ctx.get_input_type(0)
Contributor:

My concern with this is that it goes through serialization to access the type information, which is not really efficient. I would change the API so that it does not return a TypeProto but instead the type and the shape as regular Python objects.

Contributor:

I think I agree with Xavier, but I am a bit confused also. I see the method implementation serializes proto values to string and returns them. We could just return a pointer to the C++ Proto object (wrapped as a Python object). Is that your suggestion Xavier?

Contributor Author:

Should we closely mimic the C++ API design, or should we use Python's native types for interactions? Using the proto pointer for interactions may require additional code to bind it to Python (if there is another way, please correct me), or we would need to pull in a third-party library.

in1 = ctx.get_input_type(1)
in0_type, in0_shape = parse_tensor_input(in0)
in1_type, in1_shape = parse_tensor_input(in1)
assert in0_type == in1_type == TensorProto.FLOAT
assert len(in0_shape) == len(in1_shape) == 2
assert in0_shape[0] == in1_shape[0]
N, La = in0_shape
_, Lb = in1_shape
attr = ctx.get_attribute("out_len")
out_len = attr.ints
assert len(out_len) == ctx.get_num_outputs()
for i in range(ctx.get_num_outputs()):
out = ctx.get_output_type(i)
out.tensor_type.elem_type = in0_type
out.tensor_type.shape.dim.add().dim_value = N
out.tensor_type.shape.dim.add().dim_value = La * Lb
out.tensor_type.shape.dim.add().dim_value = out_len[i]
ctx.set_output_type(i, out)
Contributor:

Same comment here: we should avoid serialization with something like set_output_type_and_shape(in0_type, (N, La*Lb, out_len[i])). The type would be created on the C++ side, there would be no serialization, and it would be more efficient.
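To make that concrete, a hypothetical sketch of a serialization-free variant of `func` from the test above; `set_output_type_and_shape(index, elem_type, shape)` is assumed here and is not part of this PR, and `parse_tensor_input` is the helper defined inside `func`:

```python
# Hypothetical, serialization-free inference function (sketch only):
def func_no_serialization(ctx):
    elem_type, (n, la) = parse_tensor_input(ctx.get_input_type(0))
    _, (_, lb) = parse_tensor_input(ctx.get_input_type(1))
    out_len = ctx.get_attribute("out_len").ints
    for i, li in enumerate(out_len):
        # assumed signature: (output_index, elem_type, shape)
        ctx.set_output_type_and_shape(i, elem_type, (n, la * lb, li))
```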


schema.set_type_and_shape_inference_function(func)
onnx.defs.register_schema(schema)

self._assert_inferred(
graph,
[
make_tensor_value_info(f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li))
for i, Li in enumerate(out_len)
],
)
onnx.defs.deregister_schema(schema.name, schema.since_version, schema.domain)


if __name__ == "__main__":
unittest.main()