diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e532633b6eb35..fa0f3b27677eb 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -213,6 +213,7 @@ from .tensor.math import stanh # noqa: F401 from .tensor.math import sum # noqa: F401 from .tensor.math import nansum # noqa: F401 +from .tensor.math import nanmean # noqa: F401 from .tensor.math import tanh # noqa: F401 from .tensor.math import tanh_ # noqa: F401 from .tensor.math import add_n # noqa: F401 @@ -545,6 +546,7 @@ 'not_equal', 'sum', 'nansum', + 'nanmean', 'tile', 'greater_equal', 'isfinite', diff --git a/python/paddle/fluid/tests/unittests/test_nanmean_api.py b/python/paddle/fluid/tests/unittests/test_nanmean_api.py new file mode 100644 index 0000000000000..90a9a130899d3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nanmean_api.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard + +np.random.seed(10) + + +class TestNanmeanAPI(unittest.TestCase): + # test paddle.tensor.math.nanmean + + def setUp(self): + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) + self.x[0, :, :, :] = np.nan + self.x_grad = np.array([[np.nan, np.nan, 3.], + [0., np.nan, 2.]]).astype(np.float32) + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.x_shape) + out1 = paddle.nanmean(x) + out2 = paddle.tensor.nanmean(x) + out3 = paddle.tensor.math.nanmean(x) + axis = np.arange(len(self.x_shape)).tolist() + out4 = paddle.nanmean(x, axis) + out5 = paddle.nanmean(x, tuple(axis)) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, + fetch_list=[out1, out2, out3, out4, out5]) + out_ref = np.nanmean(self.x) + for out in res: + self.assertEqual(np.allclose(out, out_ref, rtol=1e-04), True) + + def test_api_dygraph(self): + paddle.disable_static(self.place) + + def test_case(x, axis=None, keepdim=False): + x_tensor = paddle.to_tensor(x) + out = paddle.nanmean(x_tensor, axis, keepdim) + if isinstance(axis, list): + axis = tuple(axis) + if len(axis) == 0: + axis = None + + out_ref = np.nanmean(x, axis, keepdims=keepdim) + if np.isnan(out_ref).sum(): + nan_mask = np.isnan(out_ref) + out_ref[nan_mask] = 0 + out_np = out.numpy() + out_np[nan_mask] = 0 + self.assertEqual(np.allclose(out_np, out_ref, rtol=1e-04), True) + else: + self.assertEqual( + np.allclose( + out.numpy(), out_ref, rtol=1e-04), True) + + test_case(self.x) + test_case(self.x, []) + test_case(self.x, -1) + test_case(self.x, keepdim=True) + test_case(self.x, 2, keepdim=True) + test_case(self.x, [0, 2]) + test_case(self.x, (0, 2)) + test_case(self.x, [0, 1, 2, 3]) + paddle.enable_static() + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [10, 12], 'int32') + self.assertRaises(TypeError, paddle.nanmean, x) + + def test_api_dygraph_grad(self): + paddle.disable_static(self.place) + + def test_case(x, axis=None, keepdim=False): + if isinstance(axis, list): + axis = list(axis) + if len(axis) == 0: + axis = None + x_tensor = paddle.to_tensor(x, stop_gradient=False) + y = paddle.nanmean(x_tensor, axis, keepdim) + dx = paddle.grad(y, x_tensor)[0].numpy() + sum_dx_ref = np.prod(y.shape) + if np.isnan(y.numpy()).sum(): + sum_dx_ref -= np.isnan(y.numpy()).sum() + cnt = paddle.sum(~paddle.isnan(x_tensor), + axis=axis, + keepdim=keepdim) + if (cnt == 0).sum(): + dx[np.isnan(dx)] = 0 + sum_dx = dx.sum() + self.assertEqual(np.allclose(sum_dx, sum_dx_ref, rtol=1e-04), True) + + test_case(self.x) + test_case(self.x, []) + test_case(self.x, -1) + test_case(self.x, keepdim=True) + test_case(self.x, 2, keepdim=True) + test_case(self.x, [0, 2]) + test_case(self.x, (0, 2)) + test_case(self.x, [0, 1, 2, 3]) + + test_case(self.x_grad) + test_case(self.x_grad, []) + test_case(self.x_grad, -1) + test_case(self.x_grad, keepdim=True) + test_case(self.x_grad, 0, keepdim=True) + test_case(self.x_grad, 1) + test_case(self.x_grad, (0, 1)) + paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 32902029b8a47..fc6c8f106ce4f 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -165,6 +165,7 @@ from .math import stanh # noqa: F401 from .math import sum # noqa: F401 from .math import nansum # noqa: F401 +from .math import nanmean # noqa: F401 from .math import tanh # noqa: F401 from .math import tanh_ # noqa: F401 from .math import add_n # noqa: F401 @@ -333,6 +334,7 @@ 'stanh', 'sum', 'nansum', + 'nanmean', 'tanh', 'tanh_', 'add_n', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py old mode 100755 new mode 100644 index a69ecb6db4d93..9751892e70188 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1024,6 +1024,73 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): return sum(tmp_tensor, axis, dtype, keepdim, name) +def nanmean(x, axis=None, keepdim=False, name=None): + r""" + Compute the arithmetic mean along the specified axis, ignoring NaNs. + + Args: + x (Tensor): The input Tensor with data type uint16, float16, float32, float64. + axis (int|list|tuple, optional):The axis along which to perform nanmean + calculations. ``axis`` should be int, list(int) or tuple(int). If + ``axis`` is a list/tuple of dimension(s), nanmean is calculated along + all element(s) of ``axis`` . ``axis`` or element(s) of ``axis`` + should be in range [-D, D), where D is the dimensions of ``x`` . If + ``axis`` or element(s) of ``axis`` is less than 0, it works the + same way as :math:`axis + D` . If ``axis`` is None, nanmean is + calculated over all elements of ``x``. Default is None. + keepdim (bool, optional): Whether to reserve the reduced dimension(s) + in the output Tensor. If ``keepdim`` is True, the dimensions of + the output Tensor is the same as ``x`` except in the reduced + dimensions(it is of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, results of arithmetic mean along ``axis`` of ``x``, with the same data + type as ``x``. + + Examples: + + .. code-block:: python + :name: code-example1 + + import paddle + # x is a 2-D Tensor: + x = paddle.to_tensor([[float('nan'), 0.3, 0.5, 0.9], + [0.1, 0.2, float('-nan'), 0.7]]) + out1 = paddle.nanmean(x) + # [0.44999996] + out2 = paddle.nanmean(x, axis=0) + # [0.1, 0.25, 0.5, 0.79999995] + out3 = paddle.nanmean(x, axis=0, keepdim=True) + # [[0.1, 0.25, 0.5, 0.79999995]] + out4 = paddle.nanmean(x, axis=1) + # [0.56666666 0.33333334] + out5 = paddle.nanmean(x, axis=1, keepdim=True) + # [[0.56666666] + # [0.33333334]] + + # y is a 3-D Tensor: + y = paddle.to_tensor([[[1, float('nan')], [3, 4]], + [[5, 6], [float('-nan'), 8]]]) + out6 = paddle.nanmean(y, axis=[1, 2]) + # [2.66666675, 6.33333349] + out7 = paddle.nanmean(y, axis=[0, 1]) + # [3., 6.] + """ + if isinstance(axis, int): + axis = [axis] + check_variable_and_dtype(x, 'x/input', + ['uint16', 'float16', 'float32', 'float64'], + 'nanmean' ) + if axis is not None: + check_type(axis, 'axis/dim', (int, list, tuple), 'nanmean') + + cnt = paddle.sum(~paddle.isnan(x), axis = axis,keepdim=keepdim) + return paddle.divide(paddle.nansum(x, axis=axis, keepdim=keepdim, name=name), cnt.astype(x.dtype)) + + @templatedoc(op_type="sum") def add_n(inputs, name=None): """ @@ -3941,6 +4008,7 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): else: out = elementwise_sub(input_back, input_front, axis=axis) return out + else: check_variable_and_dtype(x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff') check_type(axis, 'axis', (int), 'diff')