Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Split-20 - change uneven split behavior to be more torch-like #5321

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
2 changes: 1 addition & 1 deletion onnx/defs/tensor/defs.cc
Expand Up @@ -676,7 +676,7 @@
between the number of elements in the output tensors is minimized, which
results in the last few dimensions' sizes being lowered by one.
If 'minimize_diff' is set to 'false', the last output tensors will be emptied in order to
accomodate for the missing values.

Check warning on line 679 in onnx/defs/tensor/defs.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "accomodate" is a misspelling of "accommodate" Raw Output: ./onnx/defs/tensor/defs.cc:679:0: "accomodate" is a misspelling of "accommodate"
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Expand Down Expand Up @@ -753,7 +753,7 @@
if (ctx.hasInput(1) && num_outputs_attr) {
fail_shape_inference("Both 'split' input and 'num_outputs' attribute were given");
}
if (ctx.hasInput(1)) { // 'split' is input

Check warning on line 756 in onnx/defs/tensor/defs.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2] Raw Output: onnx/defs/tensor/defs.cc:756: At least two spaces is best between code and comments [whitespace/comments] [2]
auto split_proto = ctx.getInputData(1);
if (split_proto == nullptr) {
// skip if split is not an initializer
Expand All @@ -778,7 +778,7 @@
}
} else { // no value available for 'split'
if (num_outputs_attr) {
const int num_outputs = num_outputs_attr->i();
const long int num_outputs = num_outputs_attr->i();

Check warning on line 781 in onnx/defs/tensor/defs.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Use int16/int64/etc, rather than the C type long [runtime/int] [4] Raw Output: onnx/defs/tensor/defs.cc:781: Use int16/int64/etc, rather than the C type long [runtime/int] [4]
p-wysocki marked this conversation as resolved.
Show resolved Hide resolved
if (num_outputs < 1) {
fail_shape_inference("Attribute `num_outputs` value cannot be lower than 1");
}
Expand All @@ -790,7 +790,7 @@
int chunk_size = split_dim_value / num_outputs;
if (minimize_diff) {
int reduced_dims = num_outputs * (chunk_size + 1) - split_dim_value;
for (int i=0; i<num_outputs-reduced_dims; i++) {

Check warning on line 793 in onnx/defs/tensor/defs.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Missing spaces around < [whitespace/operators] [3] Raw Output: onnx/defs/tensor/defs.cc:793: Missing spaces around < [whitespace/operators] [3]
split.push_back(chunk_size+1);
}
while (split.size() <= num_outputs) {
Expand Down