Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hackathon No.16】add RFC for take API #217

Merged
merged 17 commits into from Aug 29, 2022
Merged
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 20 additions & 8 deletions rfcs/APIs/20220714_api_design_for_take.md
Expand Up @@ -233,11 +233,13 @@ def take(a, indices, axis=None, out=None, mode='raise'):

- `torch.take` 的 `index` 参数必须为 LongTensor 类型;`numpy.take` 直接对参数 `indices` 的元素取整再进行索引。

- 当不指定轴时,对于相同的索引矩阵,`Numpy.take` 的执行结果等于 `torch.take`。
- 在维度支持上,`numpy.take` 支持指定轴,`torch.take` 不支持

- 在维度支持上,`Numpy.take` 支持指定轴,`torch.take` 不支持
- 当不指定轴时,对于相同的索引矩阵,`numpy.take` 的执行结果等于 `torch.take`。

- `Numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错。
- `numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错。

> 由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的 `paddle.take` 的功能与 `torch.take` 和 `numpy.take` 的默认形式保持一致,即,不增加 axis 参数指定索引轴;在 `torch.take` 的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性。

# 五、方案设计

Expand All @@ -247,8 +249,9 @@ def take(a, indices, axis=None, out=None, mode='raise'):

```python
paddle.take(
input: Tensor,
x: Tensor,
index: Tensor,
mode: str='raise',
name: str=None)
```

Expand All @@ -268,21 +271,30 @@ paddle.take(

1. 通过 `Tensor.flatten()` 将输入 x 和 index 展开成 1D Tensor。

2. 通过 `Tensor.index_select(index)` 按照 index 中的索引提取对应元素。
2. 根据 mode 参数对索引进行越界处理:
- `mode='raise'`, 直接抛出错误 (默认);
- `mode='wrap'`, 通过取余约束越界的 indices;
- `mode='clip'`, 通过 `paddle.clip` 约束两端的索引。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 需要在业内方案调研的Numpy实现方法里,补充下mode参数的越界处理的逻辑和核心代码
  2. raise抛出错误,不是在mode参数阶段抛出的,是在后面的index_select地方抛出,这里的描述不够准确
  3. 需要说明下wrap是用哪个API来完成处理的。
  4. 因为Clip: Note that this disables indexing with negative numbers.有这个限制,需要说明下直接使用paddle.clip是否符合要求


3. 通过 `Tensor.index_select(index)` 按照 index 中的索引提取对应元素。
- `numpy.take` 和 `torch.take` 支持负值索引;
- 然而 `index_select` 不支持,因此需要先将 index 的负值索引转为对应的正值索引。

3. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。
4. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。


# 六、测试和验收的考量

测试考虑的 case 如下:

- 参数 `index` 数据类型必须为 `paddle.int32` 和 `paddle.int64` 类型的 Tensor(与 `paddle.index_select` 一致)。

- `input` 的数据类型支持 `int32`,`int64`,`float32`,`float64`。
- `x` 的数据类型支持 `int32`,`int64`,`float32`,`float64`。

- `index` 索引越界时直接报错。
- `index` 索引越界的三种处理方式:
- `mode='raise'`, 直接抛出错误 (默认);
- `mode='wrap'`, 通过取余约束越界的 indices;
- `mode='clip'`, 通过 `paddle.clip` 约束两端的索引。

- 在动态图、静态图下,以及 CPU、GPU 下,都能得到正确的结果。

Expand Down