Skip to content

Commit

Permalink
Add Swish operator
Browse files Browse the repository at this point in the history
Signed-off-by: isdanni <leedanni@gmail.com>
  • Loading branch information
isdanni committed May 2, 2024
1 parent 1529880 commit b0ceb4d
Show file tree
Hide file tree
Showing 213 changed files with 93 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/Changelog.md
Expand Up @@ -25735,6 +25735,35 @@ This version of the operator has been available since version 21 of the default
<dd>Constrain output to int64 tensor.</dd>
</dl>

### <a name="Swish-22"></a>**Swish-22**</a>

Swish function, takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape, where $Swish(x) = x * sigmoid(beta * x)$.

#### Version

This version of the operator has been available since version 22 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>

### <a name="Size-21"></a>**Size-21**</a>

Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor.
Expand Down
1 change: 1 addition & 0 deletions docs/Operators.md
Expand Up @@ -142,6 +142,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Sin">Sin</a>|<a href="Changelog.md#Sin-22">22</a>, <a href="Changelog.md#Sin-7">7</a>|
|<a href="#Sinh">Sinh</a>|<a href="Changelog.md#Sinh-22">22</a>, <a href="Changelog.md#Sinh-9">9</a>|
|<a href="#Size">Size</a>|<a href="Changelog.md#Size-21">21</a>, <a href="Changelog.md#Size-19">19</a>, <a href="Changelog.md#Size-13">13</a>, <a href="Changelog.md#Size-1">1</a>|
|<a href="#Swish">Swish</a>|<a href="Changelog.md#Swish-22">22</a>|
|<a href="#Slice">Slice</a>|<a href="Changelog.md#Slice-13">13</a>, <a href="Changelog.md#Slice-11">11</a>, <a href="Changelog.md#Slice-10">10</a>, <a href="Changelog.md#Slice-1">1</a>|
|<a href="#SpaceToDepth">SpaceToDepth</a>|<a href="Changelog.md#SpaceToDepth-13">13</a>, <a href="Changelog.md#SpaceToDepth-1">1</a>|
|<a href="#Split">Split</a>|<a href="Changelog.md#Split-18">18</a>, <a href="Changelog.md#Split-13">13</a>, <a href="Changelog.md#Split-11">11</a>, <a href="Changelog.md#Split-2">2</a>, <a href="Changelog.md#Split-1">1</a>|
Expand Down
29 changes: 29 additions & 0 deletions onnx/backend/test/case/node/swish.py
@@ -0,0 +1,29 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Check warning on line 4 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L4

Added line #L4 was not covered by tests

import numpy as np

Check warning on line 6 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L6

Added line #L6 was not covered by tests

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect

Check warning on line 10 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L8-L10

Added lines #L8 - L10 were not covered by tests


def swish(x: np.ndarray, alpha: np.float16) -> np.ndarray:
return x * 1 / (1 + np.exp(np.negative(x * alpha)))

Check warning on line 14 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L13-L14

Added lines #L13 - L14 were not covered by tests


class Swish(Base):
@staticmethod

Check warning on line 18 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L17-L18

Added lines #L17 - L18 were not covered by tests
def export() -> None:
node = onnx.helper.make_node(

Check warning on line 20 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L20

Added line #L20 was not covered by tests
"Swish",
inputs=["x"],
outputs=["y"],
)

x = np.array([3, 4, 5]).astype(np.float32)
y = swish(x, alpha=1.0)

Check warning on line 27 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L26-L27

Added lines #L26 - L27 were not covered by tests

expect(node, inputs=[x], outputs=[y], name="test_swish")

Check warning on line 29 in onnx/backend/test/case/node/swish.py

View check run for this annotation

Codecov / codecov/patch

onnx/backend/test/case/node/swish.py#L29

Added line #L29 was not covered by tests
Binary file modified onnx/backend/test/data/node/test_acos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_averagepool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_averagepool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_double/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_seed/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_1d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_3d/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pad/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_2d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_nd/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default_mask/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_eyelike_with_dtype/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalaveragepool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalmaxpool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bicubic/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bilinear/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_nearest/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_defaults/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_seq_length/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish_expanded/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_instancenorm_epsilon/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_instancenorm_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_dilations/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_same_lower/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_same_upper/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_defaults/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_with_peepholes/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_dilations/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_uint8/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_dilations/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_ii/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_weight/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1d2/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_rnn_seq_length/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_roialign_mode_max/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_round/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_simple_rnn_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_simple_rnn_defaults/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_thresholdedrelu/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_training_dropout/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
29 changes: 29 additions & 0 deletions onnx/defs/math/defs.cc
Expand Up @@ -621,6 +621,35 @@ ONNX_OPERATOR_SET_SCHEMA(
.SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu)
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));

static const char* Swish_ver22_doc = R"DOC(
Swish function takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape,
where $Swish(x) = x * sigmoid(beta * x)$.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Swish,
22,
OpSchema()
.SetDoc(Swish_ver22_doc)
.Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.FunctionBody(
R"ONNX(
{
Alpha = Constant <value_float: float = @alpha>()
AlphaCast = CastLike (Alpha, X)
AlphaMulX = Mul (AlphaCast, X)
SigmoidAlphaMulX = Sigmoid(AlphaMulX)
Y = Mul (X, SigmoidAlphaMulX)
}
)ONNX",
22));

static const char* Exp_ver13_doc = R"DOC(
Calculates the exponential of the given input tensor, element-wise.
)DOC";
Expand Down
2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
Expand Up @@ -1234,6 +1234,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, RNN);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GRU);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, LSTM);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, Swish);

// Iterate over schema from ai.onnx version 22
class OpSet_Onnx_ver22 {
Expand Down Expand Up @@ -1287,6 +1288,7 @@ class OpSet_Onnx_ver22 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GRU)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, LSTM)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, GridSample)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 22, Swish)>());
}
};

Expand Down
3 changes: 3 additions & 0 deletions onnx/test/version_converter/automatic_upgrade_test.py
Expand Up @@ -1301,6 +1301,9 @@ def test_Sum(self) -> None:
attrs={"consumed_inputs": [0]},
)

def test_Swish(self) -> None:
self._test_op_upgrade("Swish", 22)

def test_Tanh(self) -> None:
self._test_op_upgrade("Tanh", 1, attrs={"consumed_inputs": [0]})

Expand Down

0 comments on commit b0ceb4d

Please sign in to comment.