Skip to content

Commit

Permalink
addition of group > 1 in test and in backend
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Jan 29, 2024
1 parent 8a6b70f commit 2b8d3a3
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 3 deletions.
150 changes: 150 additions & 0 deletions onnx/backend/test/case/node/convtranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,153 @@ def export_convtranspose_autopad_same() -> None:
).astype(np.float32)

expect(node, inputs=[x, W], outputs=[y], name="test_convtranspose_autopad_same")


@staticmethod

Check warning on line 400 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L400

Added line #L400 was not covered by tests
def export_convtranspose_group_2() -> None:
x = np.array(

Check warning on line 402 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L402

Added line #L402 was not covered by tests
[
[
[
[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
[
[9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0]]
]
]
).astype(np.float32)
W = np.array(

Check warning on line 412 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L412

Added line #L412 was not covered by tests
[
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
],
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
]
]
).astype(np.float32)

node = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], group = 2)

Check warning on line 423 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L423

Added line #L423 was not covered by tests

y = np.array(

Check warning on line 425 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L425

Added line #L425 was not covered by tests
[
[
[
[0.0, 1.0, 3.0, 3.0, 2.0],
[3.0, 8.0, 15.0, 12.0, 7.0],
[9.0, 21.0, 36.0, 27.0, 15.0],
[9.0, 20.0, 33.0, 24.0, 13.0],
[6.0, 13.0, 21.0, 15.0, 8.0]
],
[
[ 9.0, 19.0, 30.0, 21.0, 11.0],
[21.0, 44.0, 69.0, 48.0, 25.0],
[36.0, 75.0, 117.0, 81.0, 42.0],
[27.0, 56.0, 87.0, 60.0, 31.0],
[15.0, 31.0, 48.0, 33.0, 17.0]
]
]
]
).astype(np.float32)

expect(node, inputs=[x, W], outputs=[y], name="test_convtranspose_group_2")

Check warning on line 446 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L446

Added line #L446 was not covered by tests

@staticmethod

Check warning on line 448 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L448

Added line #L448 was not covered by tests
def export_convtranspose_group_2_image_3() -> None:
x = np.array(

Check warning on line 450 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L450

Added line #L450 was not covered by tests
[
[
[
[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
[
[9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0]
]
],
[
[
[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]
],
[
[9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0]
]
],
[
[
[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]
],
[
[9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0]
]
]
]
).astype(np.float32)
W = np.array(

Check warning on line 477 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L477

Added line #L477 was not covered by tests
[
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
],
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
]
]
).astype(np.float32)

node = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], group = 2)

Check warning on line 488 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L488

Added line #L488 was not covered by tests

y = np.array(

Check warning on line 490 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L490

Added line #L490 was not covered by tests
[
[
[
[ 0.0, 1.0, 3.0, 3.0, 2.0],
[ 3.0, 8.0, 15.0, 12.0, 7.0],
[ 9.0, 21.0, 36.0, 27.0, 15.0],
[ 9.0, 20.0, 33.0, 24.0, 13.0],
[ 6.0, 13.0, 21.0, 15.0, 8.0]
],
[
[ 9.0, 19.0, 30.0, 21.0, 11.0],
[ 21.0, 44.0, 69.0, 48.0, 25.0],
[ 36.0, 75.0, 117.0, 81.0, 42.0],
[ 27.0, 56.0, 87.0, 60.0, 31.0],
[ 15.0, 31.0, 48.0, 33.0, 17.0]
],
],
[
[
[ 18.0, 37.0, 57.0, 39.0, 20.],
[ 39.0, 80.0, 123.0, 84.0, 43.],
[ 63.0, 129.0, 198.0, 135.0, 69.],
[ 45.0, 92.0, 141.0, 96.0, 49.],
[ 24.0, 49.0, 75.0, 51.0, 26.]
],
[
[ 9.0, 19.0, 30.0, 21.0, 11.],
[ 21.0, 44.0, 69.0, 48.0, 25.],
[ 36.0, 75.0, 117.0, 81.0, 42.],
[ 27.0, 56.0, 87.0, 60.0, 31.],
[ 15.0, 31.0, 48.0, 33.0, 17.]
],
],
[
[
[ 0.0, 1.0, 3.0, 3.0, 2.],
[ 3.0, 8.0, 15.0, 12.0, 7.],
[ 9.0, 21.0, 36.0, 27.0, 15.],
[ 9.0, 20.0, 33.0, 24.0, 13.],
[ 6.0, 13.0, 21.0, 15.0, 8.]
],
[
[ 9.0, 19.0, 30.0, 21.0, 11.],
[ 21.0, 44.0, 69.0, 48.0, 25.],
[ 36.0, 75.0, 117.0, 81.0, 42.],
[ 27.0, 56.0, 87.0, 60.0, 31.],
[ 15.0, 31.0, 48.0, 33.0, 17.]
]
]
]

).astype(np.float32)

expect(node, inputs=[x, W], outputs=[y], name="test_convtranspose_group_2_image_3")

Check warning on line 544 in onnx/backend/test/case/node/convtranspose.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/convtranspose.py#L544

Added line #L544 was not covered by tests



30 changes: 27 additions & 3 deletions onnx/reference/ops/op_conv_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,32 @@ def _run( # type: ignore
res += B[c]
final[image_id, c, ...] = res[...]
else:
raise NotImplementedError(
f"Implementation for group={group} > 1 is not available yet."
)
final = np.zeros((X.shape[0], num_output_channels ) + tuple(output_shape))
output_array = []

Check warning on line 107 in onnx/reference/ops/op_conv_transpose.py

View check run for this annotation

Codecov / codecov/patch

onnx/reference/ops/op_conv_transpose.py#L106-L107

Added lines #L106 - L107 were not covered by tests

for group_id in range(group):
group_X = X[:, group_id * C // group : (group_id + 1) * C // group, ...]
group_W = W[group_id * num_output_channels // group : (group_id + 1) * num_output_channels // group, ...]

Check warning on line 111 in onnx/reference/ops/op_conv_transpose.py

View check run for this annotation

Codecov / codecov/patch

onnx/reference/ops/op_conv_transpose.py#L110-L111

Added lines #L110 - L111 were not covered by tests

group_output = self._run(

Check warning on line 113 in onnx/reference/ops/op_conv_transpose.py

View check run for this annotation

Codecov / codecov/patch

onnx/reference/ops/op_conv_transpose.py#L113

Added line #L113 was not covered by tests
group_X,
group_W,
B=B,
auto_pad=auto_pad,
dilations=dilations,
group=1,
kernel_shape=kernel_shape,
output_padding=output_padding,
output_shape=output_shape,
pads=pads,
strides=strides,
)
group_output = np.array(group_output[0])
output_array.append(group_output)

Check warning on line 127 in onnx/reference/ops/op_conv_transpose.py

View check run for this annotation

Codecov / codecov/patch

onnx/reference/ops/op_conv_transpose.py#L126-L127

Added lines #L126 - L127 were not covered by tests

for image_id in range(X.shape[0]):
for group_id in range(group):
group_output = output_array[group_id]
final[image_id, group_id:(group_id+1), ...] = group_output[image_id, ...]

Check warning on line 132 in onnx/reference/ops/op_conv_transpose.py

View check run for this annotation

Codecov / codecov/patch

onnx/reference/ops/op_conv_transpose.py#L131-L132

Added lines #L131 - L132 were not covered by tests

return (final.astype(X.dtype),) # type: ignore[union-attr]

0 comments on commit 2b8d3a3

Please sign in to comment.