Implement GELU as function op (onnx#5277)
### Description
These changes have been made to support the GELU operator as a function
op.

### Motivation and Context
Support for [GELU: Gaussian Error Linear
Unit](https://paperswithcode.com/method/gelu) activation function, which
was requested in onnx#4933.
onnx#4423 also mentions this under the new ops section of `Contributions
Welcome`.

As per the discussion in onnx#4933, I have added GELU as a context-dependent
function op that uses the attribute `approximate` to select between the two
possible function-body definitions.

The first function definition is the regular GELU:
`GELU(x) = x * Φ(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`

The second is the fast approximation based on `tanh`:
`GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))`

This implementation uses the [PyTorch docs for
GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html?highlight=gelu#torch.nn.GELU)
as a reference.
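
To make this concrete, here is a hand-written Python sketch of how the exact-GELU function body decomposes into primitive ONNX ops. The PR's actual body is registered in C++ in `onnx/defs/math/defs.cc`; the tensor names below are made up for illustration, and a real body would also cast the constants to the input dtype (e.g. with `CastLike`):

```python
from onnx import helper

# GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) as a chain of primitive ops.
# Tensor names (Half, SqrtTwo, ...) are illustrative only.
body = [
    helper.make_node("Constant", [], ["Half"], value_float=0.5),
    helper.make_node("Constant", [], ["One"], value_float=1.0),
    helper.make_node("Constant", [], ["SqrtTwo"], value_float=2.0**0.5),
    helper.make_node("Div", ["X", "SqrtTwo"], ["XScaled"]),    # x / sqrt(2)
    helper.make_node("Erf", ["XScaled"], ["ErfX"]),            # erf(x / sqrt(2))
    helper.make_node("Add", ["ErfX", "One"], ["OnePlusErf"]),  # 1 + erf(...)
    helper.make_node("Mul", ["X", "OnePlusErf"], ["Prod"]),    # x * (1 + erf(...))
    helper.make_node("Mul", ["Half", "Prod"], ["Y"]),          # 0.5 * x * (1 + erf(...))
]
```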

PS: I also refactored `onnx/defs/math/defs.cc` to bring the operator
implementation of `mish` right next to its doc string.

---------

Signed-off-by: pranshupant <pranshupant@gmail.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Signed-off-by: Megan Hampton <hamptonm@us.ibm.com>
2 people authored and MegoHam21 committed Jul 10, 2023
1 parent 093f744 commit 56ca46f
Showing 32 changed files with 355 additions and 13 deletions.
42 changes: 42 additions & 0 deletions docs/Changelog.md
@@ -23919,6 +23919,48 @@ This version of the operator has been available since version 20 of the default ONNX operator set.
<dd>Constrain output types to be numerics.</dd>
</dl>

### <a name="Gelu-20"></a>**Gelu-20**</a>

Gelu takes one input data (Tensor<T>) and produces one
output data (Tensor<T>) where the Gaussian error linear units function,
$y = 0.5 * x * (1 + \mathrm{erf}(x/\sqrt{2}))$, is applied to the tensor elementwise.
If the attribute "approximate" is set to "tanh", the tanh approximation,
$y = 0.5 * x * (1 + \tanh(\sqrt{2/\pi} * (x + 0.044715 * x^3)))$, is applied
to the tensor elementwise instead.


#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>approximate</tt> : string (default is none)</dt>
<dd>Gelu approximation algorithm: `"tanh"` or `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use the tanh approximation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>
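
For completeness, a minimal sketch of constructing and checking a one-node model that uses the new op; the model and tensor names are arbitrary, and it assumes an onnx build whose checker knows opset 20:

```python
import onnx
from onnx import TensorProto, helper

# Build a single-node model around Gelu-20 and validate it.
node = helper.make_node("Gelu", inputs=["X"], outputs=["Y"], approximate="none")
graph = helper.make_graph(
    [node],
    "gelu_example",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
onnx.checker.check_model(model)  # passes once the installed onnx supports opset 20
```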

### <a name="GridSample-20"></a>**GridSample-20**</a>

Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
96 changes: 96 additions & 0 deletions docs/Operators.md
@@ -170,6 +170,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Clip">Clip</a>|<a href="Changelog.md#Clip-13">13</a>, <a href="Changelog.md#Clip-12">12</a>, <a href="Changelog.md#Clip-11">11</a>, <a href="Changelog.md#Clip-6">6</a>, <a href="Changelog.md#Clip-1">1</a>|13|
|<a href="#DynamicQuantizeLinear">DynamicQuantizeLinear</a>|<a href="Changelog.md#DynamicQuantizeLinear-11">11</a>|11|
|<a href="#Elu">Elu</a>|<a href="Changelog.md#Elu-6">6</a>, <a href="Changelog.md#Elu-1">1</a>|18|
|<a href="#Gelu">Gelu</a>|<a href="Changelog.md#Gelu-20">20</a>|20|
|<a href="#GreaterOrEqual">GreaterOrEqual</a>|<a href="Changelog.md#GreaterOrEqual-16">16</a>, <a href="Changelog.md#GreaterOrEqual-12">12</a>|16|
|<a href="#GroupNormalization">GroupNormalization</a>|<a href="Changelog.md#GroupNormalization-18">18</a>|18|
|<a href="#HammingWindow">HammingWindow</a>|<a href="Changelog.md#HammingWindow-17">17</a>|17|
@@ -9410,6 +9411,101 @@ expect(
</details>


### <a name="Gelu"></a><a name="gelu">**Gelu**</a>

Gelu takes one input data (Tensor<T>) and produces one
output data (Tensor<T>) where the Gaussian error linear units function,
$y = 0.5 * x * (1 + \mathrm{erf}(x/\sqrt{2}))$, is applied to the tensor elementwise.
If the attribute "approximate" is set to "tanh", the tanh approximation,
$y = 0.5 * x * (1 + \tanh(\sqrt{2/\pi} * (x + 0.044715 * x^3)))$, is applied
to the tensor elementwise instead.


#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>approximate</tt> : string (default is none)</dt>
<dd>Gelu approximation algorithm: `"tanh"` or `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use the tanh approximation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


#### Examples

<details>
<summary>gelu_default</summary>

```python
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: elementwise GELU of the random input
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
```

</details>


<details>
<summary>gelu_tanh</summary>

```python
node = onnx.helper.make_node(
    "Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
    0.5
    * x
    * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: elementwise tanh-approximate GELU of the random input
y = (
    0.5
    * x
    * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
```

</details>
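
The two variants track each other closely; here is a quick sketch (not part of the PR's test suite) to measure the pointwise gap between them over a dense grid:

```python
import math

import numpy as np

# Compare exact GELU against the tanh approximation on [-6, 6].
x = np.linspace(-6, 6, 10001)
exact = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
approx = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
print(np.max(np.abs(exact - approx)))  # maximum pointwise gap on the grid
```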


### <a name="Gemm"></a><a name="gemm">**Gemm**</a>

General Matrix multiplication:
52 changes: 51 additions & 1 deletion docs/TestCoverage.md
@@ -6,7 +6,7 @@
* [Overall Test Coverage](#overall-test-coverage)
# Node Test Coverage
## Summary
Node tests have covered 174/187 (93.05%, 5 generators excluded) common operators (previously 173/186, 93.01%).

Node tests have covered 0/0 (N/A) experimental operators.

@@ -6241,6 +6241,56 @@ expect(
</details>


### Gelu
There are 2 test cases, listed as follows:
<details>
<summary>gelu_default</summary>

```python
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: elementwise GELU of the random input
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
```

</details>
<details>
<summary>gelu_tanh</summary>

```python
node = onnx.helper.make_node(
    "Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
    0.5
    * x
    * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: elementwise tanh-approximate GELU of the random input
y = (
    0.5
    * x
    * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
```

</details>


### Gemm
There are 11 test cases, listed as follows:
<details>
Expand Down
51 changes: 51 additions & 0 deletions onnx/backend/test/case/node/gelu.py
@@ -0,0 +1,51 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

import math

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


class Gelu(Base):
    @staticmethod
    def export_gelu_tanh() -> None:
        node = onnx.helper.make_node(
            "Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
        )

        x = np.array([-1, 0, 1]).astype(np.float32)
        # expected output [-0.158808, 0., 0.841192]
        y = (
            0.5
            * x
            * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
        ).astype(np.float32)
        expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

        x = np.random.randn(3, 4, 5).astype(np.float32)
        # expected output: elementwise tanh-approximate GELU of the random input
        y = (
            0.5
            * x
            * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
        ).astype(np.float32)
        expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")

    @staticmethod
    def export_gelu_default() -> None:
        node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

        x = np.array([-1, 0, 1]).astype(np.float32)
        # expected output [-0.15865526, 0., 0.84134474]
        y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
        expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

        x = np.random.randn(3, 4, 5).astype(np.float32)
        # expected output: elementwise GELU of the random input
        y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
        expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
The remaining changed files are generated binary test data (`*.pb` input/output tensors) for the new `test_gelu_*` cases and are not shown.
