diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 63f16c4eb78f1..3578b9a1aaeea 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -329,6 +329,7 @@ from .tensor.stat import numel # noqa: F401 from .tensor.stat import median # noqa: F401 from .tensor.stat import quantile # noqa: F401 +from .tensor.stat import nanquantile # noqa: F401 from .device import get_cudnn_version # noqa: F401 from .device import set_device # noqa: F401 from .device import get_device # noqa: F401 @@ -495,6 +496,7 @@ 'numel', 'median', 'quantile', + 'nanquantile', 'no_grad', 'set_grad_enabled', 'is_grad_enabled', diff --git a/python/paddle/fluid/tests/unittests/test_quantile.py b/python/paddle/fluid/tests/unittests/test_quantile.py deleted file mode 100644 index 936d1d3be3a19..0000000000000 --- a/python/paddle/fluid/tests/unittests/test_quantile.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -import paddle - - -class TestQuantile(unittest.TestCase): - """ - This class is used for numerical precision testing. If there is - a corresponding numpy API, the precision comparison can be performed directly. - Otherwise, it needs to be verified by numpy implementated function. - """ - - def setUp(self): - np.random.seed(678) - self.input_data = np.random.rand(6, 7, 8, 9, 10) - - # Test correctness when q and axis are set. - def test_quantile_single_q(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=0.5, axis=2) - np_res = np.quantile(self.input_data, q=0.5, axis=2) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - # Test correctness for default axis. - def test_quantile_with_no_axis(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=0.35) - np_res = np.quantile(self.input_data, q=0.35) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - # Test correctness for multiple axis. - def test_quantile_with_multi_axis(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=0.75, axis=[0, 2, 3]) - np_res = np.quantile(self.input_data, q=0.75, axis=[0, 2, 3]) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - # Test correctness when keepdim is set. - def test_quantile_with_keepdim(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=0.35, axis=4, keepdim=True) - np_res = np.quantile(self.input_data, q=0.35, axis=4, keepdims=True) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - # Test correctness when all parameters are set. 
- def test_quantile_with_keepdim_and_multiple_axis(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=0.1, axis=[1, 4], keepdim=True) - np_res = np.quantile(self.input_data, q=0.1, axis=[1, 4], keepdims=True) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - # Test correctness when q = 0. - def test_quantile_with_boundary_q(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=0, axis=3) - np_res = np.quantile(self.input_data, q=0, axis=3) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - # Test correctness when input includes NaN. - def test_quantile_include_NaN(self): - input_data = np.random.randn(2, 3, 4) - input_data[0, 1, 1] = np.nan - x = paddle.to_tensor(input_data) - paddle_res = paddle.quantile(x, q=0.35, axis=0) - self.assertTrue(paddle.isnan(paddle_res[1, 1])) - - -class TestQuantileMuitlpleQ(unittest.TestCase): - """ - This class is used to test multiple input of q. - """ - - def setUp(self): - np.random.seed(678) - self.input_data = np.random.rand(10, 3, 4, 5, 4) - - def test_quantile(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=[0.3, 0.44], axis=-2) - np_res = np.quantile(self.input_data, q=[0.3, 0.44], axis=-2) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - def test_quantile_multiple_axis(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile(x, q=[0.2, 0.67], axis=[1, -1]) - np_res = np.quantile(self.input_data, q=[0.2, 0.67], axis=[1, -1]) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - def test_quantile_multiple_axis_keepdim(self): - x = paddle.to_tensor(self.input_data) - paddle_res = paddle.quantile( - x, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdim=True) - np_res = np.quantile( - self.input_data, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdims=True) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - -class TestQuantileError(unittest.TestCase): - """ - This class is used to test that exceptions are thrown correctly. - Validity of all parameter values and types should be considered. 
- """ - - def setUp(self): - self.x = paddle.randn((2, 3, 4)) - - def test_errors(self): - # Test error when q > 1 - def test_q_range_error_1(): - paddle_res = paddle.quantile(self.x, q=1.5) - - self.assertRaises(ValueError, test_q_range_error_1) - - # Test error when q < 0 - def test_q_range_error_2(): - paddle_res = paddle.quantile(self.x, q=[0.2, -0.3]) - - self.assertRaises(ValueError, test_q_range_error_2) - - # Test error with no valid q - def test_q_range_error_3(): - paddle_res = paddle.quantile(self.x, q=[]) - - self.assertRaises(ValueError, test_q_range_error_3) - - # Test error when x is not Tensor - def test_x_type_error(): - x = [1, 3, 4] - paddle_res = paddle.quantile(x, q=0.9) - - self.assertRaises(TypeError, test_x_type_error) - - # Test error when scalar axis is not int - def test_axis_type_error_1(): - paddle_res = paddle.quantile(self.x, q=0.4, axis=0.4) - - self.assertRaises(ValueError, test_axis_type_error_1) - - # Test error when axis in List is not int - def test_axis_type_error_2(): - paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, 0.4]) - - self.assertRaises(ValueError, test_axis_type_error_2) - - # Test error when axis not in [-D, D) - def test_axis_value_error_1(): - paddle_res = paddle.quantile(self.x, q=0.4, axis=10) - - self.assertRaises(ValueError, test_axis_value_error_1) - - # Test error when axis not in [-D, D) - def test_axis_value_error_2(): - paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, -10]) - - self.assertRaises(ValueError, test_axis_value_error_2) - - # Test error with no valid axis - def test_axis_value_error_3(): - paddle_res = paddle.quantile(self.x, q=0.4, axis=[]) - - self.assertRaises(ValueError, test_axis_value_error_3) - - -class TestQuantileRuntime(unittest.TestCase): - """ - This class is used to test the API could run correctly with - different devices, different data types, and dygraph/static mode. 
- """ - - def setUp(self): - np.random.seed(678) - self.input_data = np.random.rand(6, 7, 8, 9, 10) - self.dtypes = ['float32', 'float64'] - self.devices = ['cpu'] - if paddle.device.is_compiled_with_cuda(): - self.devices.append('gpu') - - def test_dygraph(self): - paddle.disable_static() - for device in self.devices: - # Check different devices - paddle.set_device(device) - for dtype in self.dtypes: - # Check different dtypes - np_input_data = self.input_data.astype(dtype) - x = paddle.to_tensor(np_input_data, dtype=dtype) - paddle_res = paddle.quantile(x, q=0.5, axis=2) - np_res = np.quantile(np_input_data, q=0.5, axis=2) - self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) - - def test_static(self): - paddle.enable_static() - for device in self.devices: - x = paddle.static.data( - name="x", shape=self.input_data.shape, dtype=paddle.float32) - x_fp64 = paddle.static.data( - name="x_fp64", - shape=self.input_data.shape, - dtype=paddle.float64) - - results = paddle.quantile(x, q=0.5, axis=2) - np_input_data = self.input_data.astype('float32') - results_fp64 = paddle.quantile(x_fp64, q=0.5, axis=2) - np_input_data_fp64 = self.input_data.astype('float64') - - exe = paddle.static.Executor(device) - paddle_res, paddle_res_fp64 = exe.run( - paddle.static.default_main_program(), - feed={"x": np_input_data, - "x_fp64": np_input_data_fp64}, - fetch_list=[results, results_fp64]) - np_res = np.quantile(np_input_data, q=0.5, axis=2) - np_res_fp64 = np.quantile(np_input_data_fp64, q=0.5, axis=2) - self.assertTrue( - np.allclose(paddle_res, np_res) and np.allclose(paddle_res_fp64, - np_res_fp64)) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_quantile_and_nanquantile.py b/python/paddle/fluid/tests/unittests/test_quantile_and_nanquantile.py new file mode 100644 index 0000000000000..f0368cd2bc34f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_quantile_and_nanquantile.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle + +API_list = [(paddle.quantile, np.quantile), + (paddle.nanquantile, np.nanquantile)] + + +class TestQuantileAndNanquantile(unittest.TestCase): + """ + This class is used for numerical precision testing. If there is + a corresponding numpy API, the precision comparison can be performed directly. + Otherwise, it needs to be verified by numpy implementated function. + """ + + def setUp(self): + self.input_data = np.random.rand(4, 7, 6) + + # Test correctness when q and axis are set. + def test_single_q(self): + inp = self.input_data + for (func, res_func) in API_list: + x = paddle.to_tensor(inp) + paddle_res = func(x, q=0.5, axis=2) + np_res = res_func(inp, q=0.5, axis=2) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + inp[0, 1, 2] = np.nan + + # Test correctness for default axis. 
+ def test_with_no_axis(self): + inp = self.input_data + for (func, res_func) in API_list: + x = paddle.to_tensor(inp) + paddle_res = func(x, q=0.35) + np_res = res_func(inp, q=0.35) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + inp[0, 2, 1] = np.nan + inp[0, 1, 2] = np.nan + + # Test correctness for multiple axis. + def test_with_multi_axis(self): + inp = self.input_data + for (func, res_func) in API_list: + x = paddle.to_tensor(inp) + paddle_res = func(x, q=0.75, axis=[0, 2]) + np_res = res_func(inp, q=0.75, axis=[0, 2]) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + inp[0, 5, 3] = np.nan + inp[0, 6, 2] = np.nan + + # Test correctness when keepdim is set. + def test_with_keepdim(self): + inp = self.input_data + for (func, res_func) in API_list: + x = paddle.to_tensor(inp) + paddle_res = func(x, q=0.35, axis=2, keepdim=True) + np_res = res_func(inp, q=0.35, axis=2, keepdims=True) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + inp[0, 3, 4] = np.nan + + # Test correctness when all parameters are set. + def test_with_keepdim_and_multiple_axis(self): + inp = self.input_data + for (func, res_func) in API_list: + x = paddle.to_tensor(inp) + paddle_res = func(x, q=0.1, axis=[1, 2], keepdim=True) + np_res = res_func(inp, q=0.1, axis=[1, 2], keepdims=True) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + inp[0, 6, 3] = np.nan + + # Test correctness when q = 0. + def test_with_boundary_q(self): + inp = self.input_data + for (func, res_func) in API_list: + x = paddle.to_tensor(inp) + paddle_res = func(x, q=0, axis=1) + np_res = res_func(inp, q=0, axis=1) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + inp[0, 2, 5] = np.nan + + # Test correctness when input includes NaN. + def test_quantile_include_NaN(self): + input_data = np.random.randn(2, 3, 4) + input_data[0, 1, 1] = np.nan + x = paddle.to_tensor(input_data) + paddle_res = paddle.quantile(x, q=0.35, axis=0) + np_res = np.quantile(input_data, q=0.35, axis=0) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res, equal_nan=True)) + + # Test correctness when input filled with NaN. + def test_nanquantile_all_NaN(self): + input_data = np.full(shape=[2, 3], fill_value=np.nan) + input_data[0, 2] = 0 + x = paddle.to_tensor(input_data) + paddle_res = paddle.nanquantile(x, q=0.35, axis=0) + np_res = np.nanquantile(input_data, q=0.35, axis=0) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res, equal_nan=True)) + + +class TestMuitlpleQ(unittest.TestCase): + """ + This class is used to test multiple input of q. 
+ """ + + def setUp(self): + self.input_data = np.random.rand(5, 3, 4) + + def test_quantile(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=[0.3, 0.44], axis=-2) + np_res = np.quantile(self.input_data, q=[0.3, 0.44], axis=-2) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_multiple_axis(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=[0.2, 0.67], axis=[1, -1]) + np_res = np.quantile(self.input_data, q=[0.2, 0.67], axis=[1, -1]) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_multiple_axis_keepdim(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile( + x, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdim=True) + np_res = np.quantile( + self.input_data, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdims=True) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + +class TestError(unittest.TestCase): + """ + This class is used to test that exceptions are thrown correctly. + Validity of all parameter values and types should be considered. + """ + + def setUp(self): + self.x = paddle.randn((2, 3, 4)) + + def test_errors(self): + # Test error when q > 1 + def test_q_range_error_1(): + paddle_res = paddle.quantile(self.x, q=1.5) + + self.assertRaises(ValueError, test_q_range_error_1) + + # Test error when q < 0 + def test_q_range_error_2(): + paddle_res = paddle.quantile(self.x, q=[0.2, -0.3]) + + self.assertRaises(ValueError, test_q_range_error_2) + + # Test error with no valid q + def test_q_range_error_3(): + paddle_res = paddle.quantile(self.x, q=[]) + + self.assertRaises(ValueError, test_q_range_error_3) + + # Test error when x is not Tensor + def test_x_type_error(): + x = [1, 3, 4] + paddle_res = paddle.quantile(x, q=0.9) + + self.assertRaises(TypeError, test_x_type_error) + + # Test error when scalar axis is not int + def test_axis_type_error_1(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=0.4) + + self.assertRaises(ValueError, test_axis_type_error_1) + + # Test error when axis in List is not int + def test_axis_type_error_2(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, 0.4]) + + self.assertRaises(ValueError, test_axis_type_error_2) + + # Test error when axis not in [-D, D) + def test_axis_value_error_1(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=10) + + self.assertRaises(ValueError, test_axis_value_error_1) + + # Test error when axis not in [-D, D) + def test_axis_value_error_2(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, -10]) + + self.assertRaises(ValueError, test_axis_value_error_2) + + # Test error with no valid axis + def test_axis_value_error_3(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=[]) + + self.assertRaises(ValueError, test_axis_value_error_3) + + +class TestQuantileRuntime(unittest.TestCase): + """ + This class is used to test the API could run correctly with + different devices, different data types, and dygraph/static mode. 
+ """ + + def setUp(self): + self.input_data = np.random.rand(4, 7) + self.dtypes = ['float32', 'float64'] + self.devices = ['cpu'] + if paddle.device.is_compiled_with_cuda(): + self.devices.append('gpu') + + def test_dygraph(self): + paddle.disable_static() + for (func, res_func) in API_list: + for device in self.devices: + # Check different devices + paddle.set_device(device) + for dtype in self.dtypes: + # Check different dtypes + np_input_data = self.input_data.astype(dtype) + x = paddle.to_tensor(np_input_data, dtype=dtype) + paddle_res = func(x, q=0.5, axis=1) + np_res = res_func(np_input_data, q=0.5, axis=1) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_static(self): + paddle.enable_static() + for (func, res_func) in API_list: + for device in self.devices: + x = paddle.static.data( + name="x", shape=self.input_data.shape, dtype=paddle.float32) + x_fp64 = paddle.static.data( + name="x_fp64", + shape=self.input_data.shape, + dtype=paddle.float64) + + results = func(x, q=0.5, axis=1) + np_input_data = self.input_data.astype('float32') + results_fp64 = func(x_fp64, q=0.5, axis=1) + np_input_data_fp64 = self.input_data.astype('float64') + + exe = paddle.static.Executor(device) + paddle_res, paddle_res_fp64 = exe.run( + paddle.static.default_main_program(), + feed={"x": np_input_data, + "x_fp64": np_input_data_fp64}, + fetch_list=[results, results_fp64]) + np_res = res_func(np_input_data, q=0.5, axis=1) + np_res_fp64 = res_func(np_input_data_fp64, q=0.5, axis=1) + self.assertTrue( + np.allclose(paddle_res, np_res) and + np.allclose(paddle_res_fp64, np_res_fp64)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 3c4647d4d6b68..5f0fb4336e014 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -262,6 +262,7 @@ from .stat import numel # noqa: F401 from .stat import median # noqa: F401 from .stat import quantile # noqa: F401 +from .stat import nanquantile # noqa: F401 from .to_string import set_printoptions # noqa: F401 @@ -445,6 +446,7 @@ 'numel', 'median', 'quantile', + 'nanquantile', 'is_complex', 'is_integer', 'rank', diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 9863abe1becbb..991b86fd47d16 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -342,13 +342,14 @@ def median(x, axis=None, keepdim=False, name=None): return out_tensor -def quantile(x, q, axis=None, keepdim=False): +def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): """ Compute the quantile of the input along the specified axis. + Args: Args: x (Tensor): The input Tensor, it's data type can be float32, float64. - q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, + q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, each q will be calculated and the first dimension of output is same to the number of ``q`` . axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . @@ -360,37 +361,28 @@ def quantile(x, q, axis=None, keepdim=False): the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. 
- name (str, optional): Name for the operation (optional, default is None). - For more information, please refer to :ref:`api_guide_Name`. + ignore_nan: (bool, optional): Whether to ignore NaN of input Tensor. + If ``ignore_nan`` is True, it will calculate nanquantile. + Otherwise it will calculate quantile. Default is False. Returns: - Tensor, results of quantile along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32. - - Examples: - .. code-block:: python - - import paddle - - x = paddle.randn((2,3)) - #[[-1.28740597, 0.49533170, -1.00698614], - # [-1.11656201, -1.01010525, -2.23457789]]) - - y1 = paddle.quantile(x, q=0.5, axis=[0, 1]) - # y1 = -1.06333363 - - y2 = paddle.quantile(x, q=0.5, axis=1) - # y2 = [-1.00698614, -1.11656201] - - y3 = paddle.quantile(x, q=[0.3, 0.5], axis=1) - # y3 =[[-1.11915410, -1.56376839], - # [-1.00698614, -1.11656201]] - - y4 = paddle.quantile(x, q=0.8, axis=1, keepdim=True) - # y4 = [[-0.10559537], - # [-1.05268800]]) + Tensor, results of quantile along ``axis`` of ``x``. + In order to obtain higher precision, data type of results will be float64. """ + # Validate x if not isinstance(x, Variable): raise TypeError("input x should be a Tensor.") + + # Validate q + if isinstance(q, (int, float)): + q = [q] + elif isinstance(q, (list, tuple)): + if len(q) <= 0: + raise ValueError("q should not be empty") + else: + raise TypeError("Type of q should be int, float, list or tuple.") + + # Validate axis dims = len(x.shape) out_shape = list(x.shape) if axis is None: @@ -399,7 +391,7 @@ def quantile(x, q, axis=None, keepdim=False): out_shape = [1] * dims else: if isinstance(axis, list): - if (len(axis) <= 0): + if len(axis) <= 0: raise ValueError("axis should not be empty") axis_src, axis_dst = [], [] for axis_single in axis: @@ -424,54 +416,177 @@ def quantile(x, q, axis=None, keepdim=False): if axis < 0: axis += dims out_shape[axis] = 1 + + mask = x.isnan() + valid_counts = mask.logical_not().sum(axis=axis, + keepdim=True, + dtype='float64') + indices = [] - if isinstance(q, (int, float)): - if q < 0 or q > 1: + + for q_num in q: + if q_num < 0 or q_num > 1: raise ValueError("q should be in range [0, 1]") - indices.append(q * (x.shape[axis] - 1)) - elif isinstance(q, (list, tuple)): - if len(q) <= 0: - raise ValueError("q should not be empty") - for q_num in q: - if q_num < 0 or q_num > 1: - raise ValueError("q should be in range [0, 1]") - indices.append(q_num * (x.shape[axis] - 1)) - else: - raise TypeError("Type of q should be int, float, list or tuple.") + if paddle.in_dynamic_mode(): + q_num = paddle.to_tensor(q_num, dtype='float64') + if ignore_nan: + indices.append(q_num * (valid_counts - 1)) + else: + # TODO(Asthestarsfalll): Use paddle.index_fill instead of where + index = q_num * (valid_counts - 1) + last_index = x.shape[axis] - 1 + nums = paddle.full_like(index, fill_value=last_index) + index = paddle.where(mask.any(axis=axis, keepdim=True), nums, index) + indices.append(index) + sorted_tensor = paddle.sort(x, axis) - indices_tensor = paddle.assign(indices).astype(paddle.float32) - indices_below = paddle.floor(indices_tensor).astype(paddle.int32) - indices_upper = paddle.ceil(indices_tensor).astype(paddle.int32) - outputs = [] - def expand_dim(indices, sorted_tensor_shape, axis): - assert axis < len(list(sorted_tensor_shape)) - expanded_shape = [1] * len(list(sorted_tensor_shape)) - expanded_shape = tuple(expanded_shape) - indices = indices.reshape(expanded_shape) - return 
indices
+    outputs = []
     # TODO(chenjianye): replace the for-loop to directly take elements.
-    for i in range(len(indices)):
-        if (indices_upper[i] != indices_below[i]):
-            tensor_below = paddle.take_along_axis(
-                sorted_tensor,
-                expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
-            tensor_upper = paddle.take_along_axis(
-                sorted_tensor,
-                expand_dim(indices_upper[i], sorted_tensor.shape, axis), axis)
-            weights = (indices[i] - indices_below[i]).astype(x.dtype)
-            out = paddle.lerp(tensor_below, tensor_upper, weights)
-        else:
-            out = paddle.take_along_axis(
-                sorted_tensor,
-                expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
+    for index in indices:
+        indices_below = paddle.floor(index).astype(paddle.int32)
+        indices_upper = paddle.ceil(index).astype(paddle.int32)
+        tensor_upper = paddle.take_along_axis(
+            sorted_tensor, indices_upper, axis=axis)
+        tensor_below = paddle.take_along_axis(
+            sorted_tensor, indices_below, axis=axis)
+        weights = (index - indices_below.astype('float64'))
+        out = paddle.lerp(
+            tensor_below.astype('float64'),
+            tensor_upper.astype('float64'), weights)
         if not keepdim:
             out = paddle.squeeze(out, axis=axis)
         else:
             out = out.reshape(out_shape)
         outputs.append(out)
-    if isinstance(q, (list, tuple)):
-        return paddle.stack(outputs, 0)
+
+    if len(q) > 1:
+        outputs = paddle.stack(outputs, 0)
     else:
-        return outputs[0]
+        outputs = outputs[0]
+
+    return outputs
+
+
+def quantile(x, q, axis=None, keepdim=False):
+    """
+    Compute the quantile of the input along the specified axis.
+    If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
+
+    Args:
+        x (Tensor): The input Tensor, its data type can be float32 or float64.
+        q (int|float|list): The q for calculating quantile, which should be in range [0, 1]. If q is a list,
+            each q will be calculated and the first dimension of the output is the same as the number of ``q`` .
+        axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
+            ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
+            If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
+            If ``axis`` is a list, quantile is calculated over all elements of the given axes.
+            If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
+        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
+            in the output Tensor. If ``keepdim`` is True, the dimensions of
+            the output Tensor are the same as ``x`` except in the reduced
+            dimensions (it is of size 1 in this case). Otherwise, the shape of
+            the output Tensor is squeezed in ``axis`` . Default is False.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, results of quantile along ``axis`` of ``x``.
+        In order to obtain higher precision, data type of results will be float64.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            x = np.arange(0, 8, dtype=np.float32).reshape(4, 2)
+            # [[0 1]
+            #  [2 3]
+            #  [4 5]
+            #  [6 7]]
+            y = paddle.to_tensor(x)
+            y1 = paddle.quantile(y, q=0.5, axis=[0, 1])
+            # 3.5
+
+            y2 = paddle.quantile(y, q=0.5, axis=1)
+            # [0.5 2.5 4.5 6.5]
+
+            y3 = paddle.quantile(y, q=[0.3, 0.5], axis=0)
+            # [[1.8 2.8]
+            #  [3.  4. ]]
+
+            x[0][0] = np.nan
+            y = paddle.to_tensor(x)
+            y4 = paddle.quantile(y, q=0.8, axis=1, keepdim=True)
+            # [[nan]
+            #  [2.8]
+            #  [4.8]
+            #  [6.8]]
+
+    """
+    return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=False)
+
+
+def nanquantile(x, q, axis=None, keepdim=False):
+    """
+    Compute the quantile of the input as if NaN values in the input did not exist.
+    If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
+
+    Args:
+        x (Tensor): The input Tensor, its data type can be float32 or float64.
+        q (int|float|list): The q for calculating quantile, which should be in range [0, 1]. If q is a list,
+            each q will be calculated and the first dimension of the output is the same as the number of ``q`` .
+        axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
+            ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
+            If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
+            If ``axis`` is a list, quantile is calculated over all elements of the given axes.
+            If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
+        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
+            in the output Tensor. If ``keepdim`` is True, the dimensions of
+            the output Tensor are the same as ``x`` except in the reduced
+            dimensions (it is of size 1 in this case). Otherwise, the shape of
+            the output Tensor is squeezed in ``axis`` . Default is False.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, results of quantile along ``axis`` of ``x``.
+        In order to obtain higher precision, data type of results will be float64.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            x = np.array(
+                [[0, 1, 2, 3, 4],
+                 [5, 6, 7, 8, 9]],
+                dtype=np.float32
+            )
+            x[0][0] = np.nan
+
+            x = paddle.to_tensor(x)
+            y1 = paddle.nanquantile(x, q=0.5, axis=[0, 1])
+            # 5.0
+
+            y2 = paddle.nanquantile(x, q=0.5, axis=1)
+            # [2.5 7. ]
+
+            y3 = paddle.nanquantile(x, q=[0.3, 0.5], axis=0)
+            # [[5.  2.5 3.5 4.5 5.5]
+            #  [5.  3.5 4.5 5.5 6.5]]
+
+            y4 = paddle.nanquantile(x, q=0.8, axis=1, keepdim=True)
+            # [[3.4]
+            #  [8.2]]
+
+            nan = paddle.full(shape=[2, 3], fill_value=np.nan)
+            y5 = paddle.nanquantile(nan, q=0.8, axis=1, keepdim=True)
+            # [[nan]
+            #  [nan]]
+
+    """
+    return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=True)
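
For reference, the scheme the new _compute_quantile helper follows per reduced row is: sort along axis, build the fractional index q * (valid_counts - 1) (for plain quantile, rows containing NaN are instead pointed at the last sorted position so the NaN propagates), then linearly interpolate between the floor and ceil neighbours with paddle.lerp. The snippet below is a minimal NumPy sketch of that scheme for a single 1-D row, not the Paddle implementation itself; quantile_1d is a hypothetical helper used only for illustration, and the NaN handling is simplified to early returns.

    import numpy as np

    def quantile_1d(row, q, ignore_nan=False):
        # Sort, locate the fractional position q * (n - 1), and linearly
        # interpolate between the two neighbouring order statistics.
        row = np.asarray(row, dtype='float64')
        if ignore_nan:
            row = row[~np.isnan(row)]
            if row.size == 0:
                return np.nan          # all-NaN row -> NaN (nanquantile behaviour)
        elif np.isnan(row).any():
            return np.nan              # any NaN poisons plain quantile
        srt = np.sort(row)
        idx = q * (srt.size - 1)       # mirrors q_num * (valid_counts - 1)
        lo, hi = int(np.floor(idx)), int(np.ceil(idx))
        weight = idx - lo
        return srt[lo] + (srt[hi] - srt[lo]) * weight   # lerp(below, upper, weight)

    print(quantile_1d([1.0, 2.0, 3.0, 4.0], 0.8))                # 3.4 (up to float rounding)
    print(quantile_1d([np.nan, 1.0, 2.0, 3.0, 4.0], 0.8, True))  # 3.4, NaN ignored
    print(quantile_1d([np.nan, 1.0, 2.0, 3.0, 4.0], 0.8))        # nan

This also mirrors what the new unit tests check: quantile matches np.quantile and propagates NaN, while nanquantile matches np.nanquantile and only returns NaN when a whole reduced row is NaN.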