Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 2】8、为 Paddle 新增 nanmean API (#40472)
Browse files Browse the repository at this point in the history
* Update __init__.py

* Update math.py

* Create test_nanmean_api.py

* Update __init__.py

* Update __init__.py

* Update math.py

* Update test_nanmean_api.py

* Update __init__.py

* Update math.py

* Update test_nanmean_api.py

* Update test_nanmean_api.py

* Update test_nanmean_api.py

* Update math.py

* Update test_nanmean_api.py

* Update math.py

Update the nanmean example code

* Update math.py

* Update math.py

* Update math.py

Remove redundant code in nanmean

* Update math.py

change default keepdim = False

* Update test_nanmean_api.py

add nan into self.x

* Update test_nanmean_api.py

rerun CI check

* Update test_nanmean_api.py

* update code of nanmean in python/paddle/tensor/math.py and test_nanmean_api.py

* Update test_nanmean_api.py

update code format

* Update test_nanmean_api.py

update code format

* Update test_nanmean_api.py

add check grad code.

* Update math.py

update nanmean's describe of Args x

* Update test_nanmean_api.py

update format and release the test_case(self.x, keepdim=True) in check grad code.

* Update test_nanmean_api.py

Update gradient checking method

* Update test_nanmean_api.py

update code format

* Update test_nanmean_api.py 

Update code format and copyright in test_nanmean_api.py

* Update math.py

update arguments describe and code example

* Update math.py

修改了nanmean的axis参数的文档描述。

* Update math.py

updata nanmean's sample code (:name: code-example1)

* Update math.py

修改nanmean的example code 错误

* Update math.py

update example code

* Update math.py

update example code of nanmean
  • Loading branch information
Li-fAngyU committed Apr 6, 2022
1 parent 176df91 commit 1d43e2d
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Expand Up @@ -213,6 +213,7 @@
from .tensor.math import stanh # noqa: F401
from .tensor.math import sum # noqa: F401
from .tensor.math import nansum # noqa: F401
from .tensor.math import nanmean # noqa: F401
from .tensor.math import tanh # noqa: F401
from .tensor.math import tanh_ # noqa: F401
from .tensor.math import add_n # noqa: F401
Expand Down Expand Up @@ -545,6 +546,7 @@
'not_equal',
'sum',
'nansum',
'nanmean',
'tile',
'greater_equal',
'isfinite',
Expand Down
137 changes: 137 additions & 0 deletions python/paddle/fluid/tests/unittests/test_nanmean_api.py
@@ -0,0 +1,137 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard

np.random.seed(10)


class TestNanmeanAPI(unittest.TestCase):
# test paddle.tensor.math.nanmean

def setUp(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.x[0, :, :, :] = np.nan
self.x_grad = np.array([[np.nan, np.nan, 3.],
[0., np.nan, 2.]]).astype(np.float32)
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.x_shape)
out1 = paddle.nanmean(x)
out2 = paddle.tensor.nanmean(x)
out3 = paddle.tensor.math.nanmean(x)
axis = np.arange(len(self.x_shape)).tolist()
out4 = paddle.nanmean(x, axis)
out5 = paddle.nanmean(x, tuple(axis))
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x},
fetch_list=[out1, out2, out3, out4, out5])
out_ref = np.nanmean(self.x)
for out in res:
self.assertEqual(np.allclose(out, out_ref, rtol=1e-04), True)

def test_api_dygraph(self):
paddle.disable_static(self.place)

def test_case(x, axis=None, keepdim=False):
x_tensor = paddle.to_tensor(x)
out = paddle.nanmean(x_tensor, axis, keepdim)
if isinstance(axis, list):
axis = tuple(axis)
if len(axis) == 0:
axis = None

out_ref = np.nanmean(x, axis, keepdims=keepdim)
if np.isnan(out_ref).sum():
nan_mask = np.isnan(out_ref)
out_ref[nan_mask] = 0
out_np = out.numpy()
out_np[nan_mask] = 0
self.assertEqual(np.allclose(out_np, out_ref, rtol=1e-04), True)
else:
self.assertEqual(
np.allclose(
out.numpy(), out_ref, rtol=1e-04), True)

test_case(self.x)
test_case(self.x, [])
test_case(self.x, -1)
test_case(self.x, keepdim=True)
test_case(self.x, 2, keepdim=True)
test_case(self.x, [0, 2])
test_case(self.x, (0, 2))
test_case(self.x, [0, 1, 2, 3])
paddle.enable_static()

def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [10, 12], 'int32')
self.assertRaises(TypeError, paddle.nanmean, x)

def test_api_dygraph_grad(self):
paddle.disable_static(self.place)

def test_case(x, axis=None, keepdim=False):
if isinstance(axis, list):
axis = list(axis)
if len(axis) == 0:
axis = None
x_tensor = paddle.to_tensor(x, stop_gradient=False)
y = paddle.nanmean(x_tensor, axis, keepdim)
dx = paddle.grad(y, x_tensor)[0].numpy()
sum_dx_ref = np.prod(y.shape)
if np.isnan(y.numpy()).sum():
sum_dx_ref -= np.isnan(y.numpy()).sum()
cnt = paddle.sum(~paddle.isnan(x_tensor),
axis=axis,
keepdim=keepdim)
if (cnt == 0).sum():
dx[np.isnan(dx)] = 0
sum_dx = dx.sum()
self.assertEqual(np.allclose(sum_dx, sum_dx_ref, rtol=1e-04), True)

test_case(self.x)
test_case(self.x, [])
test_case(self.x, -1)
test_case(self.x, keepdim=True)
test_case(self.x, 2, keepdim=True)
test_case(self.x, [0, 2])
test_case(self.x, (0, 2))
test_case(self.x, [0, 1, 2, 3])

test_case(self.x_grad)
test_case(self.x_grad, [])
test_case(self.x_grad, -1)
test_case(self.x_grad, keepdim=True)
test_case(self.x_grad, 0, keepdim=True)
test_case(self.x_grad, 1)
test_case(self.x_grad, (0, 1))
paddle.enable_static()


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Expand Up @@ -165,6 +165,7 @@
from .math import stanh # noqa: F401
from .math import sum # noqa: F401
from .math import nansum # noqa: F401
from .math import nanmean # noqa: F401
from .math import tanh # noqa: F401
from .math import tanh_ # noqa: F401
from .math import add_n # noqa: F401
Expand Down Expand Up @@ -333,6 +334,7 @@
'stanh',
'sum',
'nansum',
'nanmean',
'tanh',
'tanh_',
'add_n',
Expand Down
68 changes: 68 additions & 0 deletions python/paddle/tensor/math.py 100755 → 100644
Expand Up @@ -1024,6 +1024,73 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None):
return sum(tmp_tensor, axis, dtype, keepdim, name)


def nanmean(x, axis=None, keepdim=False, name=None):
r"""
Compute the arithmetic mean along the specified axis, ignoring NaNs.
Args:
x (Tensor): The input Tensor with data type uint16, float16, float32, float64.
axis (int|list|tuple, optional):The axis along which to perform nanmean
calculations. ``axis`` should be int, list(int) or tuple(int). If
``axis`` is a list/tuple of dimension(s), nanmean is calculated along
all element(s) of ``axis`` . ``axis`` or element(s) of ``axis``
should be in range [-D, D), where D is the dimensions of ``x`` . If
``axis`` or element(s) of ``axis`` is less than 0, it works the
same way as :math:`axis + D` . If ``axis`` is None, nanmean is
calculated over all elements of ``x``. Default is None.
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, results of arithmetic mean along ``axis`` of ``x``, with the same data
type as ``x``.
Examples:
.. code-block:: python
:name: code-example1
import paddle
# x is a 2-D Tensor:
x = paddle.to_tensor([[float('nan'), 0.3, 0.5, 0.9],
[0.1, 0.2, float('-nan'), 0.7]])
out1 = paddle.nanmean(x)
# [0.44999996]
out2 = paddle.nanmean(x, axis=0)
# [0.1, 0.25, 0.5, 0.79999995]
out3 = paddle.nanmean(x, axis=0, keepdim=True)
# [[0.1, 0.25, 0.5, 0.79999995]]
out4 = paddle.nanmean(x, axis=1)
# [0.56666666 0.33333334]
out5 = paddle.nanmean(x, axis=1, keepdim=True)
# [[0.56666666]
# [0.33333334]]
# y is a 3-D Tensor:
y = paddle.to_tensor([[[1, float('nan')], [3, 4]],
[[5, 6], [float('-nan'), 8]]])
out6 = paddle.nanmean(y, axis=[1, 2])
# [2.66666675, 6.33333349]
out7 = paddle.nanmean(y, axis=[0, 1])
# [3., 6.]
"""
if isinstance(axis, int):
axis = [axis]
check_variable_and_dtype(x, 'x/input',
['uint16', 'float16', 'float32', 'float64'],
'nanmean' )
if axis is not None:
check_type(axis, 'axis/dim', (int, list, tuple), 'nanmean')

cnt = paddle.sum(~paddle.isnan(x), axis = axis,keepdim=keepdim)
return paddle.divide(paddle.nansum(x, axis=axis, keepdim=keepdim, name=name), cnt.astype(x.dtype))


@templatedoc(op_type="sum")
def add_n(inputs, name=None):
"""
Expand Down Expand Up @@ -3941,6 +4008,7 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
else:
out = elementwise_sub(input_back, input_front, axis=axis)
return out

else:
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff')
check_type(axis, 'axis', (int), 'diff')
Expand Down

0 comments on commit 1d43e2d

Please sign in to comment.