New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741
Changes from 9 commits
982d01e
69b0a3e
b07c062
09d2836
c8482f6
0665e50
c5a9e16
10b41c4
6852760
9649b87
ec1cfd7
6806a8f
27b6943
5d32c52
b35d831
cc2f4f4
c4161f2
aaee858
ca2604f
5979d5f
7b3fc1d
668964d
64b688a
cdd1080
eca0483
4ca5c41
9fb6896
7fd6c85
046ff44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
import numpy as np | ||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
from paddle.fluid import Program, program_guard | ||
|
||
|
||
class TestTakeAPI(unittest.TestCase): | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float64' | ||
self.index_dtype = 'int64' | ||
|
||
def setUp(self): | ||
self.set_dtype() | ||
self.place = fluid.CUDAPlace( | ||
0) if core.is_compiled_with_cuda() else fluid.CPUPlace() | ||
self.input_shape = [3, 4] | ||
self.index_shape = [2, 3] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype( | ||
self.index_dtype) | ||
|
||
def test_static_graph(self): | ||
paddle.enable_static() | ||
startup_program = Program() | ||
train_program = Program() | ||
with program_guard(startup_program, train_program): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype=self.index_dtype, | ||
shape=self.index_shape) | ||
out = paddle.take(x, index) | ||
|
||
exe = fluid.Executor(self.place) | ||
st_result = exe.run(fluid.default_main_program(), | ||
feed={ | ||
'input': self.input_np, | ||
'index': self.index_np | ||
}, | ||
fetch_list=[out]) | ||
self.assertTrue( | ||
np.allclose(st_result, np.take(self.input_np, self.index_np))) | ||
|
||
def test_dygraph(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np) | ||
dy_result = paddle.take(x, index) | ||
self.assertTrue( | ||
np.allclose(np.take(self.input_np, self.index_np), | ||
dy_result.numpy())) | ||
|
||
|
||
class TestTakeInt32(TestTakeAPI): | ||
"""Test take API with data type int32""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'int32' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeInt64(TestTakeAPI): | ||
"""Test take API with data type int64""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'int64' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeFloat32(TestTakeAPI): | ||
"""Test take API with data type float32""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float32' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeTypeError(TestTakeAPI): | ||
"""Test take Type Error""" | ||
|
||
def test_static_type_error(self): | ||
"""Argument 'index' must be Tensor""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
self.assertRaises(TypeError, paddle.take, x, self.index_np) | ||
|
||
def test_dygraph_type_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
self.assertRaises(TypeError, paddle.take, x, self.index_np) | ||
|
||
def test_static_dtype_error(self): | ||
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype='float64', | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype='float32', | ||
shape=self.index_shape) | ||
self.assertRaises(TypeError, paddle.take, x, index) | ||
|
||
def test_dygraph_dtype_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np, dtype='float32') | ||
self.assertRaises(TypeError, paddle.take, x, index) | ||
|
||
|
||
class TestTakeIndexRangeError(TestTakeAPI): | ||
"""Test take index out of range error""" | ||
|
||
def setUp(self): | ||
self.set_dtype() | ||
self.place = fluid.CUDAPlace( | ||
0) if core.is_compiled_with_cuda() else fluid.CPUPlace() | ||
self.input_shape = [3, 4] | ||
self.index_shape = [2, 3] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(6, 12).reshape(self.index_shape).astype( | ||
self.index_dtype) | ||
|
||
def test_static_index_error(self): | ||
"""When the index is out of range, | ||
an error is reported directly through `paddle.index_select`""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype='float64', | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype='int64', | ||
shape=self.index_shape) | ||
self.assertRaises(ValueError, paddle.index_select, x, index) | ||
|
||
def test_dygraph_index_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np, dtype='int64') | ||
self.assertRaises(ValueError, paddle.index_select, x, index) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4700,3 +4700,71 @@ def frac(x, name=None): | |
helper.append_op( | ||
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y}) | ||
return _elementwise_op(LayerHelper(op_type, **locals())) | ||
|
||
def take(input, index, name=None): | ||
""" | ||
Returns a new tensor with the elements of input at the given index. | ||
The input tensor is treated as if it were viewed as a 1-D tensor. | ||
The result takes the same shape as the index. | ||
|
||
Args: | ||
input (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64. | ||
index (Tensor): An N-D Tensor, its data type should be int32, int64. | ||
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor: Tensor with the same shape as index, the data type is the same with input. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
import numpy as np | ||
import paddle | ||
|
||
n = np.arange(0, 12).reshape([3, 4]) | ||
x_int = paddle.to_tensor(n, dtype='int64') | ||
x_float = paddle.to_tensor(n, dtype='float64') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以使用paddle API直接生成输入的情况下,尽量避免引入第三方库哈~ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docs 的 pr 也需要改一下 |
||
|
||
idx_pos = paddle.arange(4, 10).reshape([2, 3]) # positive index | ||
idx_neg = paddle.arange(-2, 4).reshape([2, 3]) # negative index | ||
|
||
paddle.take(x_int, idx_pos) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
|
||
paddle.take(x_int, idx_neg) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[10, 11, 0 ], | ||
# [1 , 2 , 3 ]]) | ||
|
||
paddle.take(x_float, idx_pos) | ||
# Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True, | ||
# [[4., 5., 6.], | ||
# [7., 8., 9.]]) | ||
|
||
x_int.take(idx_pos) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 示例可增加一个negative index和float类型的input There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
""" | ||
|
||
if paddle.in_dynamic_mode(): | ||
if not isinstance(index, (paddle.Tensor, Variable)): | ||
raise TypeError( | ||
"The type of 'index' must be Tensor, but got {}".format(type(index))) | ||
if index.dtype not in [paddle.int32, paddle.int64]: | ||
raise TypeError( | ||
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format( | ||
index.dtype)) | ||
else: | ||
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. index索引越界时需要报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
input_1d = input.flatten() | ||
index_1d = index.flatten() | ||
|
||
# This processing enables 'take' to handle negative indexes within the correct range | ||
index_1d = paddle.where(index_1d < 0, index_1d + input_1d.shape[0], index_1d) | ||
out = input_1d.index_select(index_1d).reshape(index.shape) | ||
|
||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tensor 后使用
,
,以避免解析出Return Type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx, done.