Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take (#44741)
Browse files Browse the repository at this point in the history
  • Loading branch information
S-HuaBomb committed Aug 30, 2022
1 parent 871e332 commit 5f1a8e4
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Expand Up @@ -280,6 +280,7 @@
from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401

from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
Expand Down Expand Up @@ -656,4 +657,5 @@
'tril_indices',
'sgn',
'triu_indices',
'take',
]
246 changes: 246 additions & 0 deletions python/paddle/fluid/tests/unittests/test_take.py
@@ -0,0 +1,246 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard


class TestTakeAPI(unittest.TestCase):

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [2, 3]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype(
self.index_dtype)

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

def test_static_graph(self):
paddle.enable_static()
startup_program = Program()
train_program = Program()
with program_guard(startup_program, train_program):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
out = paddle.take(x, index, mode=self.mode)

exe = fluid.Executor(self.place)
st_result = exe.run(fluid.default_main_program(),
feed={
'input': self.input_np,
'index': self.index_np
},
fetch_list=out)
np.testing.assert_allclose(
st_result[0],
np.take(self.input_np, self.index_np, mode=self.mode))

def test_dygraph(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np)
dy_result = paddle.take(x, index, mode=self.mode)
np.testing.assert_allclose(
np.take(self.input_np, self.index_np, mode=self.mode),
dy_result.numpy())


class TestTakeInt32(TestTakeAPI):
"""Test take API with data type int32"""

def set_dtype(self):
self.input_dtype = 'int32'
self.index_dtype = 'int64'


class TestTakeInt64(TestTakeAPI):
"""Test take API with data type int64"""

def set_dtype(self):
self.input_dtype = 'int64'
self.index_dtype = 'int64'


class TestTakeFloat32(TestTakeAPI):
"""Test take API with data type float32"""

def set_dtype(self):
self.input_dtype = 'float32'
self.index_dtype = 'int64'


class TestTakeTypeError(TestTakeAPI):
"""Test take Type Error"""

def test_static_type_error(self):
"""Argument 'index' must be Tensor"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
self.assertRaises(TypeError, paddle.take, x, self.index_np,
self.mode)

def test_dygraph_type_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)

def test_static_dtype_error(self):
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype='float64',
shape=self.input_shape)
index = fluid.data(name='index',
dtype='float32',
shape=self.index_shape)
self.assertRaises(TypeError, paddle.take, x, index, self.mode)

def test_dygraph_dtype_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype='float32')
self.assertRaises(TypeError, paddle.take, x, index, self.mode)


class TestTakeModeRaisePos(unittest.TestCase):
"""Test positive index out of range error"""

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 6]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-10, 20).reshape(self.index_shape).astype(
self.index_dtype) # positive indices are out of range

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

def test_static_index_error(self):
"""When the index is out of range,
an error is reported directly through `paddle.index_select`"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
self.assertRaises(ValueError, paddle.index_select, x, index)

def test_dygraph_index_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
self.assertRaises(ValueError, paddle.index_select, x, index)


class TestTakeModeRaiseNeg(TestTakeModeRaisePos):
"""Test negative index out of range error"""

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 6]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 10).reshape(self.index_shape).astype(
self.index_dtype) # negative indices are out of range

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()


class TestTakeModeWrap(TestTakeAPI):
"""Test take index out of range mode"""

def set_mode(self):
self.mode = 'wrap'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds


class TestTakeModeClip(TestTakeAPI):
"""Test take index out of range mode"""

def set_mode(self):
self.mode = 'clip'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds


if __name__ == "__main__":
unittest.main()
8 changes: 5 additions & 3 deletions python/paddle/tensor/__init__.py
Expand Up @@ -234,6 +234,7 @@
from .math import heaviside # noqa: F401
from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401

from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
Expand Down Expand Up @@ -280,8 +281,8 @@

from .einsum import einsum # noqa: F401

#this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa
# this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ # noqa
'matmul',
'dot',
'cov',
Expand Down Expand Up @@ -505,11 +506,12 @@
'put_along_axis_',
'exponential_',
'heaviside',
'take',
'bucketize',
'sgn',
]

#this list used in math_op_patch.py for magic_method bind
# this list used in math_op_patch.py for magic_method bind
magic_method_func = [
('__and__', 'bitwise_and'),
('__or__', 'bitwise_or'),
Expand Down

0 comments on commit 5f1a8e4

Please sign in to comment.