Skip to content

Commit

Permalink
Add Swish operator
Browse files Browse the repository at this point in the history
Signed-off-by: isdanni <leedanni@gmail.com>
  • Loading branch information
isdanni committed Mar 18, 2024
1 parent 9cc907f commit e38a91a
Show file tree
Hide file tree
Showing 24 changed files with 107 additions and 15 deletions.
30 changes: 30 additions & 0 deletions docs/Changelog.md
Expand Up @@ -25735,6 +25735,36 @@ This version of the operator has been available since version 21 of the default
<dd>Constrain output to int64 tensor.</dd>
</dl>

### <a name="Swish-21"></a>**Swish-21**</a>

Swish function, takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape, where $Swish(x) = x * sigmoid(beta * x)$.

#### Version

This version of the operator has been available since version 21 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


### <a name="Size-21"></a>**Size-21**</a>

Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor.
Expand Down
1 change: 1 addition & 0 deletions docs/Operators.md
Expand Up @@ -141,6 +141,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Sign">Sign</a>|<a href="Changelog.md#Sign-13">13</a>, <a href="Changelog.md#Sign-9">9</a>|
|<a href="#Sin">Sin</a>|<a href="Changelog.md#Sin-7">7</a>|
|<a href="#Sinh">Sinh</a>|<a href="Changelog.md#Sinh-9">9</a>|
|<a href="#Swish">Swish</a>|<a href="Changelog.md#Swish-21">21</a>|
|<a href="#Size">Size</a>|<a href="Changelog.md#Size-21">21</a>, <a href="Changelog.md#Size-19">19</a>, <a href="Changelog.md#Size-13">13</a>, <a href="Changelog.md#Size-1">1</a>|
|<a href="#Slice">Slice</a>|<a href="Changelog.md#Slice-13">13</a>, <a href="Changelog.md#Slice-11">11</a>, <a href="Changelog.md#Slice-10">10</a>, <a href="Changelog.md#Slice-1">1</a>|
|<a href="#SpaceToDepth">SpaceToDepth</a>|<a href="Changelog.md#SpaceToDepth-13">13</a>, <a href="Changelog.md#SpaceToDepth-1">1</a>|
Expand Down
29 changes: 29 additions & 0 deletions onnx/backend/test/case/node/swish.py
@@ -0,0 +1,29 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


def swish(x: np.ndarray, beta: np.float16) -> np.ndarray:
beta = 1.0
return x * 1 / (1 + np.exp(np.negative(x * beta)))


class Swish(Base):
@staticmethod
def export() -> None:
node = onnx.helper.make_node(
"Swish",
inputs=["x"],
outputs=["y"],
beta=1.0,
)

x = np.array([3, 4, 5]).astype(np.float32)
y = swish(x, beta=1.0)
expect(node, inputs=[x], outputs=[y], name="test_swish")
Binary file not shown.
@@ -1 +1 @@
ByJ���@D{,@�#@��@�8@��&@w�@��8@Zm=@ @�22@*�@�@>;@���?ZT�?~/?�5@�81@Վ7@�P>@%�2@��@�d1@\<�?�N&@~2�?�G<@�6@�@���?��0@�@��@��?j[$@��#@sK$@>9<@o�)@��@��@+@V�?5�(@a�(@յ�?(ճ?�V@�k@��@\@��>@�^�?��?O��?�n'@��?<<@YZ�?
ByJ���@D{,@�#@��@�8@��&@w�@��8@[m=@ @�22@*�@�@>;@���?ZT�?~/?�5@�81@Վ7@�P>@&�2@��@�d1@\<�?�N&@}2�?�G<@�6@�@���?��0@�@��@��?j[$@��#@sK$@>9<@o�)@��@��@+@V�?4�(@a�(@յ�?'ճ?�V@�k@��@\@��>@�^�?��?O��?�n'@��?<<@YZ�?
Expand Down
@@ -1 +1 @@
ByJ�ĸ?PL?*�%?�?���>��3?J��>���?8s�?l{�>��i?�? �?�_�?���=ت�=�<��{?#Hd?��?���?fm?C��>F@e?[��=��1?M>`H�?T� ?���>��>Z�b?%��>$�?c�<um*?�(?3*?��?D@?{H�>]��>k�E?9�v=;�:?q-<?�Y>�c>N�>X��>�S?/x�>L��?)Z�=�yW>��%>�06?�>a�>��|>
ByJ�ø?OL?(�%?�?���>��3?J��>���?8s�?l{�>��i?�? �?�_�?���=ت�=�<��{?"Hd?��?���?fm?C��>E@e?[��=��1?M>`H�?T� ?���>��>Z�b?%��>$�?c�<tm*?�(?3*?��?B@?{H�>]��>k�E?9�v=:�:?p-<?�Y>�c>N�>X��>�S?/x�>L��?)Z�=�yW>��%>�06?�>a�>��|>
Expand Down
@@ -1 +1 @@
ByJ�ɚ�?���>%�]?,��?)��?�~]��{X?ng��ӽ'��>��>���?A�3?b��=�c�>�ѧ><��?��P�ȝ>!<F��Wտ�??>H?��/��q�?�����\;=�>�W=�?��?� >@G�>�L�����\����t>!��?M=�?���v����Pj��$��Q馿��?����I�پ�P��� 7?LϠ�D<X��4N�4u�>bQ��-s����漎�>� �=���>���ߵ�
ByJ�ɚ�?���>%�]?,��?(��?�~]��{X?ng��ӽ'��>��>���?A�3?b��=�c�>�ѧ><��?��P�ȝ>!<F��Wտ�??>H?��/��q�?�����\;=�>�V=�?��?� >AG�>�L�����\����t>!��?M=�?���v����Pj��$��P馿��?����J�پ�P��� 7?LϠ�D<X��4N�4u�>cQ��-s����漎�>� �=���>���ߵ�
Expand Down
@@ -1 +1 @@
ByJ�#�?~��>�OF?V�?�"�?�F�܂B?~��p�ҽ�z�>4~>|�w?��&?���=���>��>� {?�4O��W�>=�4��F���3?{6?�r#�z�?m�w��K;=�{=�}#~?�%y?j>��>��9�:8���l����>�fc?0�`?�3���N��� O�u��4���m�?jS�fӾH�e�c.)?���d�V�B�:�(�>Q��'8^����s3�>��=�b�>ٵ�t*��
ByJ�#�?��>�OF?V�?�"�?�F�܂B?~��p�ҽ�z�>4~>|�w?��&?���=���>��>� {?�4O��W�>=�4��F���3?{6?�r#�z�?m�w��K;=�{=�}#~?�%y?j>��>��9�98���l����>�fc?0�`?�3���N��� O�u��4���m�?jS�fӾH�e�c.)?���d�V�A�:�(�>Q��&8^����s3�>��=�b�>ٵ�t*��
Expand Down
@@ -1,2 +1,2 @@
ByJ���?q�e?̍2?eo?�~�>+�D?�@�>B�?�W�?���>���?�?�%?/9�?0��=��=Ԧ�<�*�?5�?���?C�@�R�?���>��?w]�=�B?��>
��??1?���>���>��?!�>;,%?��<Ϟ8?W6?�T8?۸�?;U?��>���>��\?��v=�N?��O?��Z>g�>(3�>�&�>)�%?��>��$@���=�Y>�&>��G?��>T^?�v>
ByJ���?q�e?̍2?eo?�~�>+�D?�@�>B�?�W�?���>���?�?�%?/9�?0��=��=զ�<�*�?5�?���?C�@�R�?���>��?w]�=�B?��>
��??1?���>���>��?!�>;,%?��<Ξ8?W6?�T8?ܸ�?;U?��>���>��\?��v=�N?��O?��Z>h�>'3�>�&�>)�%?��>��$@���=�Y>�&>��G?��>T^?�v>
Expand Down
@@ -1 +1 @@
ByJ��<@@�b�?C\�?!!�@'T@�%�?�?�?x�?���?9�?�T�?gz@1�?��?<ь?)1�?��@��?�R�??��?���@LT�?��?]�?_|�@w}@P"�?�?�?U@�r@&��?C�?ص?�]l@@Ӈ?S��?@��?q9�?ٸ�?��?��?� @�6@H�e@���?�z�?�J�?��?2'@P�?^׶?T��?��?��?�?/�?�H�?l�?_��?���?
ByJ��<@@�b�?B\�?!!�@'T@�%�?�?�?x�?���?9�?�T�?gz@1�?��?;ь?)1�?��@��?�R�?>��?���@LT�?��?]�?^|�@w}@P"�?�?�?U@�r@%��?C�?ص?�]l@@Ӈ?S��?@��?q9�?ظ�?��?��?� @�6@G�e@���?�z�?�J�?��?2'@Q�?^׶?T��?��?��?�?/�?�H�?l�?_��?���?
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1 +1 @@
ByJ�$E5@gd�>�B�?̹�@�,J@����6��?����ӽ�.�>�>��@//V?F��=E��>w�>�b@�S��>�7v�H��a3?��y?6�O��-�@��~;=,�@�uH@W�@DM>cD�>�)����c����"� >0R�?^��?�N˾�%��`2���O��ru*��\@g1����ݸͿ��[?�W�5~[�\�����>��¼�>����>�S�=�<�>V-��ҽ�
ByJ�%E5@hd�>�B�?͹�@�,J@����6��?����ӽ�.�>�>��@./V?F��=E��>w�>�b@�S��>�7v�H��a3?��y?6�O��-�@��~;=,�@�uH@W�@DM>cD�>�)����c����"� >0R�?_��?�N˾�%��`2���O��ru*��\@f1����ݸͿ��[?�W�6~[�\�����>��¼�>����>�S�=�<�>V-��ҽ�
Expand Down
@@ -1 +1 @@
ByJ�R������>�R�?=����JQ�f�����?-/�j%ԽF��>�>��A[�s?pm�=�u�>5z�>&PAzU����>s꒿��*?�D? �?��j��I�����̎;=]�A��T�A�0A��>�h�>�P��Q@@!���6k!>��4@��%@��оD����e޿������@�> �7���RmB���{?Y~�Ap1]��Ο�͛�>pw�՟������>�m�=�ğ>xZ<�
ByJ�R������>�R�?=����JQ�f�����?-/�k%ԽF��>��>��A[�s?pm�=�u�>5z�>&PAzU����>s꒿��*?�D?!�?��j��I�����̎;=]�A��T�A�0A��>�h�>�P��P@@!���6k!>��4@��%@��оD����e޿������@�> �7���RmB���{?Y~�Ap1]��Ο�̛�>pw�ԟ������>�m�=�ğ>xZ<�~
Expand Down
27 changes: 27 additions & 0 deletions onnx/defs/math/defs.cc
Expand Up @@ -644,6 +644,33 @@ ONNX_OPERATOR_SET_SCHEMA(
.SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu)
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));

static const char* Swish_ver21_doc = R"DOC(
Swish function takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape,
where $Swish(x) = x * sigmoid(beta * x)$.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Swish,
21,
OpSchema()
.Attr("beta", "Value of beta.", AttributeProto::FLOAT, 1.0f)
.SetDoc(Swish_ver21_doc)
.Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.FunctionBody(
R"ONNX(
{
S_X = Sigmoid<beta = 1.0>(X)
Y = Mul (X, S_X)
}
)ONNX",
21));

static const char* Exp_ver13_doc = R"DOC(
Calculates the exponential of the given input tensor, element-wise.
)DOC";
Expand Down
2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
Expand Up @@ -1157,6 +1157,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Size);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Squeeze);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Transpose);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Unsqueeze);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Swish);

// Iterate over schema from ai.onnx version 21
class OpSet_Onnx_ver21 {
Expand All @@ -1182,6 +1183,7 @@ class OpSet_Onnx_ver21 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Squeeze)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Transpose)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, Unsqueeze)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 21, SiLU)>());
}
};

Expand Down
3 changes: 3 additions & 0 deletions onnx/test/version_converter/automatic_upgrade_test.py
Expand Up @@ -1196,6 +1196,9 @@ def test_Sigmoid(self) -> None:
def test_Sign(self) -> None:
self._test_op_upgrade("Sign", 9)

def test_Swish(self) -> None:
self._test_op_upgrade("Swish", 21)

def test_Sinh(self) -> None:
self._test_op_upgrade("Sinh", 9)

Expand Down

0 comments on commit e38a91a

Please sign in to comment.