Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support set schema inference function in python #5940

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

OYCN
Copy link
Contributor

@OYCN OYCN commented Feb 16, 2024

Description

Support set schema inference callback function for custion OpSchema before register in python.

example:

def func(ctx: onnx.shape_inference.InferenceContext):
    # get some info
    assert ctx.get_num_inputs() == 2
    value = ctx.get_input_type(0)
    ...
    # get or create output proto object
    output = ctx.get_output_type(0)
    # set type or shape
    ...
    # set the result proto
    ctx.set_output_type(0, output)

schema.set_type_and_shape_inference_function(func)

Note

Depends on #5906

Motivation and Context

Follow up of #5019

Signed-off-by: opluss <opluss@qq.com>
@OYCN
Copy link
Contributor Author

OYCN commented Feb 16, 2024

We're spending a considerable amount of code on passing proto objects between C++ and Python. However, the repository seems well-prepared for this feature.

Signed-off-by: opluss <opluss@qq.com>
Copy link

codecov bot commented Feb 16, 2024

Codecov Report

Attention: Patch coverage is 93.54839% with 4 lines in your changes are missing coverage. Please review.

Project coverage is 57.11%. Comparing base (83194ed) to head (01f955d).
Report is 13 commits behind head on main.

Files Patch % Lines
onnx/shape_inference.py 83.33% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5940      +/-   ##
==========================================
+ Coverage   56.95%   57.11%   +0.15%     
==========================================
  Files         506      506              
  Lines       30467    30544      +77     
  Branches     4592     4602      +10     
==========================================
+ Hits        17353    17445      +92     
+ Misses      12285    12273      -12     
+ Partials      829      826       -3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

onnx/cpp2py_export.cc Fixed Show resolved Hide resolved
onnx/cpp2py_export.cc Fixed Show resolved Hide resolved
onnx/cpp2py_export.cc Fixed Show resolved Hide resolved
Signed-off-by: oPluss <opluss@qq.com>
add check in `InferenceContext` py wrapper

Signed-off-by: opluss <opluss@qq.com>
Signed-off-by: oPluss <opluss@qq.com>
Signed-off-by: opluss <opluss@qq.com>
Signed-off-by: opluss <opluss@qq.com>
@OYCN
Copy link
Contributor Author

OYCN commented Mar 29, 2024

Hi @justinchuby ,

In this PR, we can implement shape inference on the Python side, similar to how it's done on the C++ side. If you have any suggestions for this implementation, I'm open to making adjustments accordingly.

@justinchuby justinchuby marked this pull request as ready for review March 29, 2024 18:38
@justinchuby justinchuby requested a review from a team as a code owner March 29, 2024 18:38
@justinchuby
Copy link
Contributor

Thank you! Is it ready to be reviewed?

@OYCN OYCN changed the title [WIP] Support set schema inference function in python Support set schema inference function in python Mar 31, 2024
@OYCN
Copy link
Contributor Author

OYCN commented Mar 31, 2024

Thank you! Is it ready to be reviewed?

Yes, I have removed the 'WIP' prefix from the title. Please feel free to leave any comments. :D

@justinchuby justinchuby self-assigned this Apr 1, 2024
)

assert ctx.get_num_inputs() == 2
in0 = ctx.get_input_type(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern with this is it goes through serialization to access the type information. It is not really efficient. I would change the API so that it does not return a TypeProto but the type and the shape as regular python objects.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I agree with Xavier, but I am a bit confused also. I see the method implementation serializes proto values to string and returns them. We could just return a pointer to the C++ Proto object (wrapped as a Python object). Is that your suggestion Xavier?

out.tensor_type.shape.dim.add().dim_value = N
out.tensor_type.shape.dim.add().dim_value = La * Lb
out.tensor_type.shape.dim.add().dim_value = out_len[i]
ctx.set_output_type(i, out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here, we should avoid serialization with something like set_output_type_and_shape(in0_type, (N, La*Lb, out_lin[i]). The type is created on C++ side, there is no serialization and it would be more efficient.

@xadupre
Copy link
Contributor

xadupre commented Apr 9, 2024

It would be a nice feature to have.

@@ -114,6 +116,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
#endif // ONNX_ML
);

// Avoid Segmentation fault if we not free the python function in Custom Schema
onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); }));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify when this gets invoked?

}
auto attr = ctx->getAttribute(name);
if (attr == nullptr) {
fail_shape_inference("Internal error: `attr` is nullptr in `__get_attribute`");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not an error situation. In C++, this single method plays a dual role, both to check if an attribute is present, and to get its value. Is it possible to return None in this situation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, just a return a C++ pointer (using pybind's conventions ... it must have some standard approach).

}
auto tensor = ctx->getInputData(index);
if (tensor == nullptr) {
fail_shape_inference("Internal error: `tensor` is nullptr in `__get_input_data`");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, not an error. Signifies no static-constant-value is known for this input. Best to map this to None or some similar thing

}
auto stensor = ctx->getInputSparseData(index);
if (stensor == nullptr) {
fail_shape_inference("Internal error: `stensor` is nullptr in `__get_input_sparse_data`");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above. Map to None?

@gramalingam
Copy link
Contributor

Thanks for creating the PR! It would be great to add this functionality. My comments above.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: In progress
Development

Successfully merging this pull request may close these issues.

None yet

4 participants