New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741
Changes from 5 commits
982d01e
69b0a3e
b07c062
09d2836
c8482f6
0665e50
c5a9e16
10b41c4
6852760
9649b87
ec1cfd7
6806a8f
27b6943
5d32c52
b35d831
cc2f4f4
c4161f2
aaee858
ca2604f
5979d5f
7b3fc1d
668964d
64b688a
cdd1080
eca0483
4ca5c41
9fb6896
7fd6c85
046ff44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
import numpy as np | ||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
from paddle.fluid import Program, program_guard | ||
|
||
|
||
class TestTakeAPI(unittest.TestCase): | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float64' | ||
self.index_dtype = 'int64' | ||
|
||
def setUp(self): | ||
self.set_dtype() | ||
self.place = fluid.CUDAPlace( | ||
0) if core.is_compiled_with_cuda() else fluid.CPUPlace() | ||
self.input_shape = [3, 4] | ||
self.index_shape = [2, 3] | ||
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype( | ||
self.input_dtype) | ||
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype( | ||
self.index_dtype) | ||
|
||
def test_static_graph(self): | ||
paddle.enable_static() | ||
startup_program = Program() | ||
train_program = Program() | ||
with program_guard(startup_program, train_program): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype=self.index_dtype, | ||
shape=self.index_shape) | ||
out = paddle.take(x, index) | ||
|
||
exe = fluid.Executor(self.place) | ||
st_result = exe.run(fluid.default_main_program(), | ||
feed={ | ||
'input': self.input_np, | ||
'index': self.index_np | ||
}, | ||
fetch_list=[out]) | ||
self.assertTrue( | ||
np.allclose(st_result, np.take(self.input_np, self.index_np))) | ||
|
||
def test_dygraph(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np) | ||
dy_result = paddle.take(x, index) | ||
self.assertTrue( | ||
np.allclose(np.take(self.input_np, self.index_np), | ||
dy_result.numpy())) | ||
|
||
|
||
class TestTakeInt32(TestTakeAPI): | ||
"""Test take API with data type int32""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'int32' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeInt64(TestTakeAPI): | ||
"""Test take API with data type int64""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'int64' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeFloat32(TestTakeAPI): | ||
"""Test take API with data type float32""" | ||
|
||
def set_dtype(self): | ||
self.input_dtype = 'float32' | ||
self.index_dtype = 'int64' | ||
|
||
|
||
class TestTakeType(TestTakeAPI): | ||
"""Test take Error""" | ||
|
||
def test_static_type_error(self): | ||
"""Argument 'index' must be Tensor""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype=self.input_dtype, | ||
shape=self.input_shape) | ||
self.assertRaises(TypeError, paddle.take, x, self.index_np) | ||
|
||
def test_dygraph_type_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
self.assertRaises(TypeError, paddle.take, x, self.index_np) | ||
|
||
def test_static_dtype_error(self): | ||
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]""" | ||
paddle.enable_static() | ||
with program_guard(Program()): | ||
x = fluid.data(name='input', | ||
dtype='float64', | ||
shape=self.input_shape) | ||
index = fluid.data(name='index', | ||
dtype='float32', | ||
shape=self.index_shape) | ||
self.assertRaises(TypeError, paddle.take, x, index) | ||
|
||
def test_dygraph_dtype_error(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.input_np) | ||
index = paddle.to_tensor(self.index_np, dtype='float32') | ||
self.assertRaises(TypeError, paddle.take, x, index) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4655,3 +4655,56 @@ def frac(x, name=None): | |
helper.append_op( | ||
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y}) | ||
return _elementwise_op(LayerHelper(op_type, **locals())) | ||
|
||
def take(input, index, name=None): | ||
""" | ||
Returns a new tensor with the elements of input at the given indices. | ||
The input tensor is treated as if it were viewed as a 1-D tensor. | ||
The result takes the same shape as the indices. | ||
|
||
Args: | ||
input (Tensor): An N-D Tensor, which data type should be int32, int64, float32, float64. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. which data type-》its data type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
index (Tensor): An N-D Tensor, which data type should be int32, int64. | ||
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor: Tensor with the same shape as index, the data type is the same with input. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tensor 后使用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thx, done. |
||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
import paddle | ||
|
||
x = paddle.arange(0, 12).reshape([3, 4]) | ||
idx = paddle.arange(4, 10).reshape([2, 3]) | ||
|
||
paddle.take(x, idx) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
|
||
x.take(idx) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 示例可增加一个negative index和float类型的input There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
""" | ||
|
||
if paddle.in_dynamic_mode(): | ||
if not isinstance(index, (paddle.Tensor, Variable)): | ||
raise TypeError( | ||
"The type of 'index' must be Tensor, but got {}".format(type(index))) | ||
if index.dtype not in [paddle.int32, paddle.int64]: | ||
raise TypeError( | ||
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format( | ||
index.dtype)) | ||
else: | ||
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. index索引越界时需要报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
input_1d = input.flatten() | ||
index_1d = index.flatten() | ||
|
||
# This processing enables 'take' to handle negative indexes within the correct range | ||
index_1d = paddle.where(index_1d < 0, index_1d + input_1d.shape[0], index_1d) | ||
out = input_1d.index_select(index_1d).reshape(index.shape) | ||
|
||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
缺少index索引越界的报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我直接通过
paddle.index_select
来报错。Done