diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 060c62e9ec0c1..e419f09479a9c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -280,6 +280,7 @@ from .tensor.math import heaviside # noqa: F401 from .tensor.math import frac # noqa: F401 from .tensor.math import sgn # noqa: F401 +from .tensor.math import take # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -656,4 +657,5 @@ 'tril_indices', 'sgn', 'triu_indices', + 'take', ] diff --git a/python/paddle/fluid/tests/unittests/test_take.py b/python/paddle/fluid/tests/unittests/test_take.py new file mode 100644 index 0000000000000..6e58a3a43de47 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_take.py @@ -0,0 +1,246 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard + + +class TestTakeAPI(unittest.TestCase): + + def set_mode(self): + self.mode = 'raise' + + def set_dtype(self): + self.input_dtype = 'float64' + self.index_dtype = 'int64' + + def set_input(self): + self.input_shape = [3, 4] + self.index_shape = [2, 3] + self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( + self.input_dtype) + self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype( + self.index_dtype) + + def setUp(self): + self.set_mode() + self.set_dtype() + self.set_input() + self.place = fluid.CUDAPlace( + 0) if core.is_compiled_with_cuda() else fluid.CPUPlace() + + def test_static_graph(self): + paddle.enable_static() + startup_program = Program() + train_program = Program() + with program_guard(startup_program, train_program): + x = fluid.data(name='input', + dtype=self.input_dtype, + shape=self.input_shape) + index = fluid.data(name='index', + dtype=self.index_dtype, + shape=self.index_shape) + out = paddle.take(x, index, mode=self.mode) + + exe = fluid.Executor(self.place) + st_result = exe.run(fluid.default_main_program(), + feed={ + 'input': self.input_np, + 'index': self.index_np + }, + fetch_list=out) + np.testing.assert_allclose( + st_result[0], + np.take(self.input_np, self.index_np, mode=self.mode)) + + def test_dygraph(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.input_np) + index = paddle.to_tensor(self.index_np) + dy_result = paddle.take(x, index, mode=self.mode) + np.testing.assert_allclose( + np.take(self.input_np, self.index_np, mode=self.mode), + dy_result.numpy()) + + +class TestTakeInt32(TestTakeAPI): + """Test take API with data type int32""" + + def set_dtype(self): + self.input_dtype = 'int32' + self.index_dtype = 'int64' + + +class TestTakeInt64(TestTakeAPI): + """Test take API with data type int64""" + + def set_dtype(self): + self.input_dtype = 'int64' + self.index_dtype = 'int64' + + +class TestTakeFloat32(TestTakeAPI): + """Test take API with data type float32""" + + def set_dtype(self): + self.input_dtype = 'float32' + self.index_dtype = 'int64' + + +class TestTakeTypeError(TestTakeAPI): + """Test take Type Error""" + + def test_static_type_error(self): + """Argument 'index' must be Tensor""" + paddle.enable_static() + with program_guard(Program()): + x = fluid.data(name='input', + dtype=self.input_dtype, + shape=self.input_shape) + self.assertRaises(TypeError, paddle.take, x, self.index_np, + self.mode) + + def test_dygraph_type_error(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.input_np) + self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode) + + def test_static_dtype_error(self): + """Data type of argument 'index' must be in [paddle.int32, paddle.int64]""" + paddle.enable_static() + with program_guard(Program()): + x = fluid.data(name='input', + dtype='float64', + shape=self.input_shape) + index = fluid.data(name='index', + dtype='float32', + shape=self.index_shape) + self.assertRaises(TypeError, paddle.take, x, index, self.mode) + + def test_dygraph_dtype_error(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.input_np) + index = paddle.to_tensor(self.index_np, dtype='float32') + self.assertRaises(TypeError, paddle.take, x, index, self.mode) + + +class TestTakeModeRaisePos(unittest.TestCase): + """Test positive index out of range error""" + + def set_mode(self): + self.mode = 'raise' + + def set_dtype(self): + self.input_dtype = 'float64' + self.index_dtype = 'int64' + + def set_input(self): + self.input_shape = [3, 4] + self.index_shape = [5, 6] + self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( + self.input_dtype) + self.index_np = np.arange(-10, 20).reshape(self.index_shape).astype( + self.index_dtype) # positive indices are out of range + + def setUp(self): + self.set_mode() + self.set_dtype() + self.set_input() + self.place = fluid.CUDAPlace( + 0) if core.is_compiled_with_cuda() else fluid.CPUPlace() + + def test_static_index_error(self): + """When the index is out of range, + an error is reported directly through `paddle.index_select`""" + paddle.enable_static() + with program_guard(Program()): + x = fluid.data(name='input', + dtype=self.input_dtype, + shape=self.input_shape) + index = fluid.data(name='index', + dtype=self.index_dtype, + shape=self.index_shape) + self.assertRaises(ValueError, paddle.index_select, x, index) + + def test_dygraph_index_error(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.input_np) + index = paddle.to_tensor(self.index_np, dtype=self.index_dtype) + self.assertRaises(ValueError, paddle.index_select, x, index) + + +class TestTakeModeRaiseNeg(TestTakeModeRaisePos): + """Test negative index out of range error""" + + def set_mode(self): + self.mode = 'raise' + + def set_dtype(self): + self.input_dtype = 'float64' + self.index_dtype = 'int64' + + def set_input(self): + self.input_shape = [3, 4] + self.index_shape = [5, 6] + self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( + self.input_dtype) + self.index_np = np.arange(-20, 10).reshape(self.index_shape).astype( + self.index_dtype) # negative indices are out of range + + def setUp(self): + self.set_mode() + self.set_dtype() + self.set_input() + self.place = fluid.CUDAPlace( + 0) if core.is_compiled_with_cuda() else fluid.CPUPlace() + + +class TestTakeModeWrap(TestTakeAPI): + """Test take index out of range mode""" + + def set_mode(self): + self.mode = 'wrap' + + def set_input(self): + self.input_shape = [3, 4] + self.index_shape = [5, 8] + self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( + self.input_dtype) + self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype( + self.index_dtype) # Both ends of the index are out of bounds + + +class TestTakeModeClip(TestTakeAPI): + """Test take index out of range mode""" + + def set_mode(self): + self.mode = 'clip' + + def set_input(self): + self.input_shape = [3, 4] + self.index_shape = [5, 8] + self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( + self.input_dtype) + self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype( + self.index_dtype) # Both ends of the index are out of bounds + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 42da4030dec4c..a5c06cee85093 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -234,6 +234,7 @@ from .math import heaviside # noqa: F401 from .math import frac # noqa: F401 from .math import sgn # noqa: F401 +from .math import take # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -280,8 +281,8 @@ from .einsum import einsum # noqa: F401 -#this list used in math_op_patch.py for _binary_creator_ -tensor_method_func = [ #noqa +# this list used in math_op_patch.py for _binary_creator_ +tensor_method_func = [ # noqa 'matmul', 'dot', 'cov', @@ -505,11 +506,12 @@ 'put_along_axis_', 'exponential_', 'heaviside', + 'take', 'bucketize', 'sgn', ] -#this list used in math_op_patch.py for magic_method bind +# this list used in math_op_patch.py for magic_method bind magic_method_func = [ ('__and__', 'bitwise_and'), ('__or__', 'bitwise_or'), diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 3acd9d5897aa1..1d8b6a126152e 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4748,7 +4748,6 @@ def frac(x, name=None): type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y}) return _elementwise_op(LayerHelper(op_type, **locals())) - def sgn(x, name=None): """ For complex tensor, this API returns a new tensor whose elements have the same angles as the corresponding @@ -4789,3 +4788,105 @@ def sgn(x, name=None): return paddle.as_complex(output) else: return paddle.sign(x) + +def take(x, index, mode='raise', name=None): + """ + Returns a new tensor with the elements of input tensor x at the given index. + The input tensor is treated as if it were viewed as a 1-D tensor. + The result takes the same shape as the index. + + Args: + x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64. + index (Tensor): An N-D Tensor, its data type should be int32, int64. + mode (str, optional): Specifies how out-of-bounds index will behave. the candicates are ``'raise'``, ``'wrap'`` and ``'clip'``. + + - ``'raise'``: raise an error (default); + - ``'wrap'``: wrap around; + - ``'clip'``: clip to the range. ``'clip'`` mode means that all indices that are too large are replaced by the index that addresses the last element. Note that this disables indexing with negative numbers. + + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, Tensor with the same shape as index, the data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + + x_int = paddle.arange(0, 12).reshape([3, 4]) + x_float = x_int.astype(paddle.float64) + + idx_pos = paddle.arange(4, 10).reshape([2, 3]) # positive index + idx_neg = paddle.arange(-2, 4).reshape([2, 3]) # negative index + idx_err = paddle.arange(-2, 13).reshape([3, 5]) # index out of range + + paddle.take(x_int, idx_pos) + # Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + # [[4, 5, 6], + # [7, 8, 9]]) + + paddle.take(x_int, idx_neg) + # Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + # [[10, 11, 0 ], + # [1 , 2 , 3 ]]) + + paddle.take(x_float, idx_pos) + # Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True, + # [[4., 5., 6.], + # [7., 8., 9.]]) + + x_int.take(idx_pos) + # Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + # [[4, 5, 6], + # [7, 8, 9]]) + + paddle.take(x_int, idx_err, mode='wrap') + # Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True, + # [[10, 11, 0 , 1 , 2 ], + # [3 , 4 , 5 , 6 , 7 ], + # [8 , 9 , 10, 11, 0 ]]) + + paddle.take(x_int, idx_err, mode='clip') + # Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True, + # [[0 , 0 , 0 , 1 , 2 ], + # [3 , 4 , 5 , 6 , 7 ], + # [8 , 9 , 10, 11, 11]]) + + """ + if mode not in ['raise', 'wrap', 'clip']: + raise ValueError( + "'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode)) + + if paddle.in_dynamic_mode(): + if not isinstance(index, (paddle.Tensor, Variable)): + raise TypeError( + "The type of 'index' must be Tensor, but got {}".format(type(index))) + if index.dtype not in [paddle.int32, paddle.int64]: + raise TypeError( + "The data type of 'index' must be one of ['int32', 'int64'], but got {}".format( + index.dtype)) + + else: + check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take') + + input_1d = x.flatten() + index_1d = index.flatten() + max_index = input_1d.shape[-1] + + if mode == 'raise': + # This processing enables 'take' to handle negative indexes within the correct range. + index_1d = paddle.where(index_1d < 0, index_1d + max_index, index_1d) + elif mode == 'wrap': + # The out of range indices are constrained by taking the remainder. + index_1d = paddle.where(index_1d < 0, + index_1d % max_index, index_1d) + index_1d = paddle.where(index_1d >= max_index, + index_1d % max_index, index_1d) + elif mode == 'clip': + # 'clip' mode disables indexing with negative numbers. + index_1d = clip(index_1d, 0, max_index - 1) + + out = input_1d.index_select(index_1d).reshape(index.shape) + + return out