Skip to content

Commit

Permalink
Fix conversion from split 11 to split 18. (#6020)
Browse files Browse the repository at this point in the history
Addresses #5222 

There does not exist a Split 12 from what I can see. I believe whomever
made the conversion for Split 12 to Split 13 made a mistake and intended
to do so for Split 11 to Split 13?

IIUC, the conversions should chain. Running this with my built ONNX
changes seems to work fine

```
>>> import onnx
rt onnxruntime
model = onnx.load('simple_model_split.onnx')
model = onnx.version_converter.convert_version(model, 18)
onnx.checker.check_model(model)  # success
sess = onnxruntime.InferenceSession(model.SerializeToString())  # fail>>> import onnxruntime
>>> model = onnx.load('simple_model_split.onnx')
>>> model = onnx.version_converter.convert_version(model, 18)
>>> onnx.checker.check_model(model)  # success
>>> sess = onnxruntime.InferenceSession(model.SerializeToString())  # fail
>>>
>>> sess
<onnxruntime.capi.onnxruntime_inference_collection.InferenceSession object at 0x7f82dcce0a30>
>>>

>>> import onnxruntime as rt
>>> import numpy as np
>>> model_serialized = model.SerializeToString()
>>> sess = rt.InferenceSession(model_serialized)
>>> dummy_input = np.random.randn(1, 1, 28, 28).astype(np.float32)
 Perfo>>>
>>> outputs = sess.run(output_names=['output1', 'output2'],
...                    input_feed={'input': dummy_input})

>>> print(outputs)
[array([[[[ 4.02100086e-01,  1.91692543e+00,  5.60453236e-01, ...,
           1.10894358e+00, -9.10616279e-01,  2.63174176e-01],
         [ 9.16809678e-01,  8.39324117e-01,  1.30101681e+00, ...,
          -6.86384559e-01,  8.50693703e-01, -2.23367453e-01],
         [ 9.90843654e-01,  1.80581224e+00, -6.30877018e-01, ...,
          -4.17214930e-02,  9.42695260e-01, -9.54421163e-01],
```

Inferencing seems to work. I should remove the old split 12-> 13
converters but I would like confirmation from a contributor before I
move forward.

---------

Signed-off-by: Danial Javady <danialjavady96@gmail.com>
  • Loading branch information
ZelboK committed Mar 15, 2024
1 parent 6eb45c0 commit 9cc907f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions onnx/test/version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,25 @@ def test_split_13_12(self) -> None:
assert converted_model.graph.node[0].op_type == "Split"
assert converted_model.opset_import[0].version == 12

def test_split_with_optional_input(self) -> None:

nodes = [helper.make_node("Split", ["X"], ["Y1", "Y2"], axis=1)]
graph = helper.make_graph(
nodes,
"test_split_optional_input",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (6,))],
[
helper.make_tensor_value_info("Y1", TensorProto.FLOAT, (3,)),
helper.make_tensor_value_info("Y2", TensorProto.FLOAT, (3,)),
],
)
converted_model = self._converted(graph, helper.make_operatorsetid("", 12), 18)

assert converted_model.graph.node[0].op_type == "Split"
assert converted_model.opset_import[0].version == 18

assert len(converted_model.graph.node[0].output) == 2

# Test Split Adapter: 12 -> 13
def test_split_12_13(self) -> None:
nodes = [helper.make_node("Split", ["X"], ["Y1", "Y2"], split=[2, 3])]
Expand Down
2 changes: 1 addition & 1 deletion onnx/version_converter/adapters/split_17_18.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Split_17_18 : public Adapter {

void adapt_split_17_18(std::shared_ptr<Graph>, Node* node) const {
const auto num_outputs = node->outputs().size();
SetAttribute(knum_outputs, num_outputs);
node->i_(knum_outputs, num_outputs);
}

Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
Expand Down

0 comments on commit 9cc907f

Please sign in to comment.