Add Swish operator
Signed-off-by: isdanni <leedanni@gmail.com>
isdanni committed Apr 26, 2024
1 parent 4e7289d commit 2633a85
Showing 6 changed files with 92 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/Changelog.md
@@ -25735,6 +25735,35 @@ This version of the operator has been available since version 21 of the default
<dd>Constrain output to int64 tensor.</dd>
</dl>

### <a name="Swish-22"></a>**Swish-22**</a>

Swish function takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape, where $Swish(x) = x * sigmoid(alpha * x)$ and alpha is a float attribute that defaults to 1.0.

#### Version

This version of the operator has been available since version 22 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>alpha</tt> : float (default is 1.0)</dt>
<dd>Coefficient multiplied with the input before applying the sigmoid.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>
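
As a quick illustration of the formula above (an editor's sketch, not part of the specification; the function name is arbitrary and the scaling factor alpha is assumed to default to 1.0):

```python
import numpy as np


def swish_reference(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    # Swish(x) = x * sigmoid(alpha * x); with alpha = 1.0 this is the SiLU function.
    return x * (1.0 / (1.0 + np.exp(-alpha * x)))


x = np.array([-1.0, 0.0, 1.0, 2.0], dtype=np.float32)
print(swish_reference(x))  # approximately [-0.2689, 0.0, 0.7311, 1.7616]
```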

### <a name="Size-21"></a>**Size-21**</a>

Takes a tensor as input and outputs an int64 scalar equal to the total number of elements of the input tensor.
1 change: 1 addition & 0 deletions docs/Operators.md
@@ -142,6 +142,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Sin">Sin</a>|<a href="Changelog.md#Sin-22">22</a>, <a href="Changelog.md#Sin-7">7</a>|
|<a href="#Sinh">Sinh</a>|<a href="Changelog.md#Sinh-22">22</a>, <a href="Changelog.md#Sinh-9">9</a>|
|<a href="#Size">Size</a>|<a href="Changelog.md#Size-21">21</a>, <a href="Changelog.md#Size-19">19</a>, <a href="Changelog.md#Size-13">13</a>, <a href="Changelog.md#Size-1">1</a>|
|<a href="#Swish">Swish</a>|<a href="Changelog.md#Swish-22">22</a>|
|<a href="#Slice">Slice</a>|<a href="Changelog.md#Slice-13">13</a>, <a href="Changelog.md#Slice-11">11</a>, <a href="Changelog.md#Slice-10">10</a>, <a href="Changelog.md#Slice-1">1</a>|
|<a href="#SpaceToDepth">SpaceToDepth</a>|<a href="Changelog.md#SpaceToDepth-13">13</a>, <a href="Changelog.md#SpaceToDepth-1">1</a>|
|<a href="#Split">Split</a>|<a href="Changelog.md#Split-18">18</a>, <a href="Changelog.md#Split-13">13</a>, <a href="Changelog.md#Split-11">11</a>, <a href="Changelog.md#Split-2">2</a>, <a href="Changelog.md#Split-1">1</a>|
28 changes: 28 additions & 0 deletions onnx/backend/test/case/node/swish.py
@@ -0,0 +1,28 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


def swish(x: np.ndarray, alpha: float) -> np.ndarray:
    # Reference implementation: Swish(x) = x * sigmoid(alpha * x).
    return x * 1 / (1 + np.exp(np.negative(x * alpha)))


class Swish(Base):
    @staticmethod
    def export() -> None:
        node = onnx.helper.make_node(
            "Swish",
            inputs=["x"],
            outputs=["y"],
        )

        x = np.array([3, 4, 5]).astype(np.float32)
        y = swish(x, alpha=1.0)

        expect(node, inputs=[x], outputs=[y], name="test_swish")

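For context, `expect` wraps the node above into a single-node model and records the inputs and outputs as a backend test case. A rough sketch of the equivalent model built by hand (assuming the Swish schema from this commit is registered under the default domain in opset 22; the graph name and shapes are illustrative):

```python
import onnx

# One Swish node, float32 in and out, mirroring the test case above.
node = onnx.helper.make_node("Swish", inputs=["x"], outputs=["y"])
graph = onnx.helper.make_graph(
    [node],
    "swish_test",
    [onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [3])],
    [onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3])],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 22)])
onnx.checker.check_model(model)  # succeeds only once the new schema is available
```
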
29 changes: 29 additions & 0 deletions onnx/defs/math/defs.cc
@@ -621,6 +621,35 @@ ONNX_OPERATOR_SET_SCHEMA(
.SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu)
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));

static const char* Swish_ver22_doc = R"DOC(
Swish function takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape,
where $Swish(x) = x * sigmoid(alpha * x)$ and alpha is a float attribute that defaults to 1.0.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
    Swish,
    22,
    OpSchema()
        .SetDoc(Swish_ver22_doc)
        .Attr("alpha", "Coefficient multiplied with the input before applying the sigmoid.", AttributeProto::FLOAT, 1.0f)
        .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
        .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
        .TypeConstraint(
            "T",
            {"tensor(float16)", "tensor(float)", "tensor(bfloat16)"},
            "Constrain input and output types to float tensors.")
        .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
        .FunctionBody(
            R"ONNX(
        {
            Alpha = Constant <value_float: float = @alpha>()
            AlphaCast = CastLike (Alpha, X)
            AlphaMulX = Mul (AlphaCast, X)
            SigmoidAlphaMulX = Sigmoid (AlphaMulX)
            Y = Mul (X, SigmoidAlphaMulX)
        }
        )ONNX",
            22));
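
The function body above expands a Swish node into primitive ops (Constant, CastLike, Mul, Sigmoid, Mul). A hedged Python sketch of the same decomposition, built explicitly with `onnx.helper` and run through the ONNX reference evaluator (alpha is hard-coded to 1.0 here and the tensor names simply mirror the function body):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# Y = X * Sigmoid(alpha * X), spelled out as the primitive ops used by the function body.
nodes = [
    helper.make_node("Constant", [], ["Alpha"], value_float=1.0),
    helper.make_node("CastLike", ["Alpha", "X"], ["AlphaCast"]),
    helper.make_node("Mul", ["AlphaCast", "X"], ["AlphaMulX"]),
    helper.make_node("Sigmoid", ["AlphaMulX"], ["SigmoidAlphaMulX"]),
    helper.make_node("Mul", ["X", "SigmoidAlphaMulX"], ["Y"]),
]
graph = helper.make_graph(
    nodes,
    "swish_expanded",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, None)],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 22)])

x = np.array([3.0, 4.0, 5.0], dtype=np.float32)
print(ReferenceEvaluator(model).run(None, {"X": x})[0])  # approximately [2.8577, 3.9281, 4.9665]
```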

static const char* Exp_ver13_doc = R"DOC(
Calculates the exponential of the given input tensor, element-wise.
)DOC";
2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
@@ -1234,6 +1234,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, RNN);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GRU);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, LSTM);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, Swish);

// Iterate over schema from ai.onnx version 22
class OpSet_Onnx_ver22 {
@@ -1287,6 +1288,7 @@ class OpSet_Onnx_ver22 {
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GRU)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, LSTM)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GridSample)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, Swish)>());
  }
};
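
Once the schema class is registered in the opset table above (and the commit is built into the Python package), the operator should be discoverable from Python. A small sketch of that check, using only public `onnx.defs` APIs and assuming a build that includes this commit:

```python
import onnx.defs

# Look up the newly registered schema; this only works on a build that includes this commit.
if onnx.defs.has("Swish"):
    schema = onnx.defs.get_schema("Swish", max_inclusive_version=22)
    print(schema.name, schema.since_version)  # expected: Swish 22
else:
    print("Swish is not registered in this onnx build")
```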

3 changes: 3 additions & 0 deletions onnx/test/version_converter/automatic_upgrade_test.py
@@ -1301,6 +1301,9 @@ def test_Sum(self) -> None:
attrs={"consumed_inputs": [0]},
)

def test_Swish(self) -> None:
self._test_op_upgrade("Swish", 22)

def test_Tanh(self) -> None:
self._test_op_upgrade("Tanh", 1, attrs={"consumed_inputs": [0]})

