Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741

Merged
merged 29 commits into from Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
982d01e
add paddle.take api
S-HuaBomb Jul 15, 2022
69b0a3e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Jul 15, 2022
b07c062
fix paddle.take
S-HuaBomb Jul 29, 2022
09d2836
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Jul 29, 2022
c8482f6
remove from pip import main
S-HuaBomb Jul 29, 2022
0665e50
test index out of range error
S-HuaBomb Aug 4, 2022
c5a9e16
test index out of range error and fix conflict
S-HuaBomb Aug 4, 2022
10b41c4
fix Examples
S-HuaBomb Aug 5, 2022
6852760
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 5, 2022
9649b87
fix Examples
S-HuaBomb Aug 5, 2022
ec1cfd7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 5, 2022
6806a8f
add param mode to take api
S-HuaBomb Aug 22, 2022
27b6943
fix conflict ad merge
S-HuaBomb Aug 22, 2022
5d32c52
add example code
S-HuaBomb Aug 22, 2022
b35d831
fix test using np.testing.assert_allclose
S-HuaBomb Aug 23, 2022
cc2f4f4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 23, 2022
c4161f2
add annotation
S-HuaBomb Aug 23, 2022
aaee858
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 23, 2022
ca2604f
fix typo
S-HuaBomb Aug 23, 2022
5979d5f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 23, 2022
7b3fc1d
fix 嵌套列表
S-HuaBomb Aug 24, 2022
668964d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 24, 2022
64b688a
fix Tensor,
S-HuaBomb Aug 24, 2022
cdd1080
fix docs warning
S-HuaBomb Aug 25, 2022
eca0483
fix conflict
S-HuaBomb Aug 25, 2022
4ca5c41
fix raise bug
S-HuaBomb Aug 27, 2022
9fb6896
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 27, 2022
7fd6c85
add test case for negative index out of range error
S-HuaBomb Aug 29, 2022
046ff44
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Aug 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
95 changes: 73 additions & 22 deletions python/paddle/fluid/tests/unittests/test_take.py
Expand Up @@ -24,21 +24,28 @@

class TestTakeAPI(unittest.TestCase):

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def setUp(self):
self.set_dtype()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [2, 3]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype(
self.index_dtype)

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

def test_static_graph(self):
paddle.enable_static()
startup_program = Program()
Expand All @@ -50,7 +57,7 @@ def test_static_graph(self):
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
out = paddle.take(x, index)
out = paddle.take(x, index, mode=self.mode)

exe = fluid.Executor(self.place)
st_result = exe.run(fluid.default_main_program(),
Expand All @@ -60,15 +67,17 @@ def test_static_graph(self):
},
fetch_list=[out])
self.assertTrue(
np.allclose(st_result, np.take(self.input_np, self.index_np)))
np.allclose(
st_result,
np.take(self.input_np, self.index_np, mode=self.mode)))

def test_dygraph(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np)
dy_result = paddle.take(x, index)
dy_result = paddle.take(x, index, mode=self.mode)
self.assertTrue(
np.allclose(np.take(self.input_np, self.index_np),
np.allclose(np.take(self.input_np, self.index_np, mode=self.mode),
dy_result.numpy()))


Expand Down Expand Up @@ -106,12 +115,13 @@ def test_static_type_error(self):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
self.assertRaises(TypeError, paddle.take, x, self.index_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np,
self.mode)

def test_dygraph_type_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)

def test_static_dtype_error(self):
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
Expand All @@ -123,48 +133,89 @@ def test_static_dtype_error(self):
index = fluid.data(name='index',
dtype='float32',
shape=self.index_shape)
self.assertRaises(TypeError, paddle.take, x, index)
self.assertRaises(TypeError, paddle.take, x, index, self.mode)

def test_dygraph_dtype_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype='float32')
self.assertRaises(TypeError, paddle.take, x, index)
self.assertRaises(TypeError, paddle.take, x, index, self.mode)


class TestTakeIndexRangeError(TestTakeAPI):
class TestTakeModeRaise(unittest.TestCase):
"""Test take index out of range error"""

def set_mode(self):
self.mode = 'raise'

def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds

def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
self.input_shape = [3, 4]
self.index_shape = [2, 3]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(6, 12).reshape(self.index_shape).astype(
self.index_dtype)

def test_static_index_error(self):
"""When the index is out of range,
an error is reported directly through `paddle.index_select`"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype='float64',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype='int64',
dtype=self.index_dtype,
shape=self.index_shape)
self.assertRaises(ValueError, paddle.index_select, x, index)

def test_dygraph_index_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype='int64')
index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
self.assertRaises(ValueError, paddle.index_select, x, index)


class TestTakeModeWrap(TestTakeAPI):
"""Test take index out of range mode"""

def set_mode(self):
self.mode = 'wrap'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds


class TestTakeModeClip(TestTakeAPI):
"""Test take index out of range mode"""

def set_mode(self):
self.mode = 'clip'

def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds


if __name__ == "__main__":
unittest.main()
40 changes: 32 additions & 8 deletions python/paddle/tensor/math.py
Expand Up @@ -4701,15 +4701,22 @@ def frac(x, name=None):
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y})
return _elementwise_op(LayerHelper(op_type, **locals()))

def take(input, index, name=None):
def take(x, index, mode='raise', name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name of parameter needs to be consistent with rfc, input in rfc while x here, and mode is not in rfc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeff41404 根据之前的修改意见 PaddlePaddle/community#186 (review) 更新过RFC:PaddlePaddle/community#217
参数的名字按照新的RFC内容进行修改的。

@S-HuaBomb 请先修改完RFC的评审意见吧。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rfc is still old now, should update and merge rfc first

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the modified RFC PaddlePaddle/community#217 with instructions added

"""
Returns a new tensor with the elements of input at the given index.
The input tensor is treated as if it were viewed as a 1-D tensor.
Returns a new tensor with the elements of tnput tensor x at the given index.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tnput?是个 typo 嘛?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,fixed,done

The input tensor is treated as if it were viewed as a 1-D tensor.
The result takes the same shape as the index.

Args:
input (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64.
x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64.
index (Tensor): An N-D Tensor, its data type should be int32, int64.
mode (str, optional): Specifies how out-of-bounds index will behave.
the candicates are ``'raise'`` | ``'wrap'`` | ``'clip'``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用 , 分隔即可,下面的内容按照中文文档那边的意见统一改成列表吧~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, done.

If :attr:`mode` is ``'raise'``, raise an error (default);
If :attr:`mode` is ``'wrap'``, wrap around;
If :attr:`mode` is ``'clip'``, clip to the range.
``'clip'`` mode means that all indices that are too large are replaced by the index that
addresses the last element. Note that this disables indexing with negative numbers.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Expand Down Expand Up @@ -4746,6 +4753,9 @@ def take(input, index, name=None):
# [[4, 5, 6],
# [7, 8, 9]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

示例可增加一个negative index和float类型的input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"""
if mode not in ['raise', 'wrap', 'clip']:
raise ValueError(
"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode))

if paddle.in_dynamic_mode():
if not isinstance(index, (paddle.Tensor, Variable)):
Expand All @@ -4755,14 +4765,28 @@ def take(input, index, name=None):
raise TypeError(
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
index.dtype))

else:
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index索引越界时需要报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


input_1d = input.flatten()
input_1d = x.flatten()
index_1d = index.flatten()
max_index = input_1d.shape[-1]

if mode == 'raise':
# This processing enables 'take' to handle negative indexes within the correct range.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以补充下注释,negative indexes可以enable,但越界的索引会在下面的index_select报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THX,Done

index_1d = paddle.where(index_1d < 0, index_1d % max_index, index_1d)
pass
elif mode == 'wrap':
# The out of range indices are constrained by taking the remainder.
index_1d = paddle.where(index_1d < 0,
index_1d % max_index, index_1d)
index_1d = paddle.where(index_1d >= max_index,
index_1d % max_index, index_1d)
elif mode == 'clip':
# 'clip' mode disables indexing with negative numbers.
index_1d = clip(index_1d, 0, max_index - 1)

# This processing enables 'take' to handle negative indexes within the correct range
index_1d = paddle.where(index_1d < 0, index_1d + input_1d.shape[0], index_1d)
out = input_1d.index_select(index_1d).reshape(index.shape)

return out