Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Swish operator #5964

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/Changelog.md
Expand Up @@ -28712,4 +28712,3 @@ This version of the operator has been available since version 1 of the 'ai.onnx.
<dt><tt>T3</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain input types to float tensors.</dd>
</dl>

1 change: 0 additions & 1 deletion docs/Operators.md
Expand Up @@ -36187,4 +36187,3 @@ expect(

</details>


29 changes: 29 additions & 0 deletions onnx/backend/test/case/node/swish.py
@@ -0,0 +1,29 @@
# Copyright (c) ONNX Project Contributors
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Check warning on line 4 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L4

Added line #L4 was not covered by tests

import numpy as np

Check warning on line 6 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L6

Added line #L6 was not covered by tests

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect

Check warning on line 10 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L8-L10

Added lines #L8 - L10 were not covered by tests


def swish(x: np.ndarray, alpha: np.float16) -> np.ndarray:
return x * 1 / (1 + np.exp(np.negative(x * alpha)))

Check warning on line 14 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L13-L14

Added lines #L13 - L14 were not covered by tests


class Swish(Base):
@staticmethod

Check warning on line 18 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L17-L18

Added lines #L17 - L18 were not covered by tests
def export() -> None:
node = onnx.helper.make_node(

Check warning on line 20 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L20

Added line #L20 was not covered by tests
"Swish",
inputs=["x"],
outputs=["y"],
)

x = np.array([3, 4, 5]).astype(np.float32)
y = swish(x, alpha=1.0)

Check warning on line 27 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L26-L27

Added lines #L26 - L27 were not covered by tests

expect(node, inputs=[x], outputs=[y], name="test_swish")

Check warning on line 29 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L29

Added line #L29 was not covered by tests
Binary file modified onnx/backend/test/data/node/test_acos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_double/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_seed/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_1d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_3d/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pad/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +1,2 @@
*B x_zero_point
*B
zero_point
@@ -1 +1,2 @@
*B x_zero_point
*B
zero_point
Binary file modified onnx/backend/test/data/node/test_det_2d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_nd/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_eyelike_with_dtype/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalaveragepool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalmaxpool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bicubic/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_nearest/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_defaults/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_seq_length/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish_expanded/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_default/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_defaults/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_default/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_uint8/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_default/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_ii/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1d2/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +1,2 @@
*B y_zero_point
*B
zero_point
@@ -1 +1,2 @@
*B y_zero_point
*B
zero_point
Binary file modified onnx/backend/test/data/node/test_rnn_seq_length/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_roialign_mode_max/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_round/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_thresholdedrelu/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_training_dropout/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
29 changes: 29 additions & 0 deletions onnx/defs/math/defs.cc
Expand Up @@ -621,6 +621,35 @@ ONNX_OPERATOR_SET_SCHEMA(
.SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu)
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));

static const char* Swish_ver22_doc = R"DOC(
Swish function takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape,
where $Swish(x) = x * sigmoid(beta * x)$.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Swish,
22,
OpSchema()
.SetDoc(Swish_ver22_doc)
.Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.TypeConstraint(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

beta is missing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are using SiLU I'm assuming beta is 1.0? I can add it tho

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice HardSwish uses alpha and beta in the form alpha*x + beta ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on #5964 (comment) I will add beta and rename the operator to swish, thanks for reviewing this!

"T",
{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.FunctionBody(
R"ONNX(
{
Alpha = Constant <value_float: float = @alpha>()
AlphaCast = CastLike (Alpha, X)
AlphaMulX = Mul (AlphaCast, X)
SigmoidAlphaMulX = Sigmoid(AlphaMulX)
Y = Mul (X, SigmoidAlphaMulX)
}
)ONNX",
22));

static const char* Exp_ver13_doc = R"DOC(
Calculates the exponential of the given input tensor, element-wise.
)DOC";
Expand Down
2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
Expand Up @@ -1234,6 +1234,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, RNN);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GRU);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, LSTM);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, Swish);

// Iterate over schema from ai.onnx version 22
class OpSet_Onnx_ver22 {
Expand Down Expand Up @@ -1287,6 +1288,7 @@ class OpSet_Onnx_ver22 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GRU)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, LSTM)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GridSample)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, Swish)>());
}
};

Expand Down
3 changes: 3 additions & 0 deletions onnx/test/version_converter/automatic_upgrade_test.py
Expand Up @@ -1301,6 +1301,9 @@ def test_Sum(self) -> None:
attrs={"consumed_inputs": [0]},
)

def test_Swish(self) -> None:
self._test_op_upgrade("Swish", 22)

def test_Tanh(self) -> None:
self._test_op_upgrade("Tanh", 1, attrs={"consumed_inputs": [0]})

Expand Down