Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hackathon No.16】add RFC for take API #217

Merged
merged 17 commits into from Aug 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
132 changes: 124 additions & 8 deletions rfcs/APIs/20220714_api_design_for_take.md
Expand Up @@ -217,6 +217,7 @@ def take(a, indices, axis=None, out=None, mode='raise'):
"""
return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode)
```

整体逻辑为:

- 当指定 `axis` 的时候,`numpy.take` 执行与 “fancy indexing” 相同的索引操作(使用数组索引数组);例如 `np.take(arr, indices, axis=3)` 等价于 `arr[:, :, :, indices, ...]`。
Expand All @@ -225,6 +226,109 @@ def take(a, indices, axis=None, out=None, mode='raise'):

- 当提供参数 `out` 的时候,输出的数据将填充到 `out` 中。

numpy.take 的 mode 参数的越界处理的逻辑和核心代码[在这里](https://github.com/numpy/numpy/blob/54c52f13713f3d21795926ca4dbb27e16fada171/numpy/core/src/multiarray/item_selection.c#L34-L134):

```c++
// raise 模式,在索引超出最小(大)范围后直接报错
check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
PyThreadState * _save)
{
/* Check that index is valid, taking into account negative indices */
if (NPY_UNLIKELY((*index < -max_item) || (*index >= max_item))) {
NPY_END_THREADS;
/* Try to be as clear as possible about what went wrong. */
if (axis >= 0) {
PyErr_Format(PyExc_IndexError,
"index %"NPY_INTP_FMT" is out of bounds "
"for axis %d with size %"NPY_INTP_FMT,
*index, axis, max_item);
} else {
PyErr_Format(PyExc_IndexError,
"index %"NPY_INTP_FMT" is out of bounds "
"for size %"NPY_INTP_FMT, *index, max_item);
}
return -1;
}
/* adjust negative indices */
if (*index < 0) {
*index += max_item;
}
return 0;
}
```

```c++
npy_fasttake_impl(
char *dest, char *src, const npy_intp *indices, ...)
{
NPY_BEGIN_THREADS_DEF;
NPY_BEGIN_THREADS_DESCR(dtype);
switch (clipmode) {
case NPY_RAISE:
/**
raise 模式,通过上方 check_and_adjust_index 函数判断越界索引
*/
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (check_and_adjust_index(&tmp, max_item, axis,
_save) < 0) {
return -1;
}
...
}
break;
case NPY_WRAP:
/**
wrap 模式,将超出范围的索引取余。
由于 C 语言异号取余原则与 Python 不同,
因此使用循环加减来模拟 Python 取余结果。
*/
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (tmp < 0) {
while (tmp < 0) {
tmp += max_item;
}
}
else if (tmp >= max_item) {
while (tmp >= max_item) {
tmp -= max_item;
}
}
...
}
}
break;
case NPY_CLIP:
/**
clip 模式,将两端超出范围的索引约束到 [0, max_item-1]
*/
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (tmp < 0) {
tmp = 0;
}
else if (tmp >= max_item) {
tmp = max_item - 1;
}
...
}
}
break;
}
return 0;
}

```
综上,mode 索引越界处理方式总结为

- `mode='raise'`,判断索引范围并抛出错误 (默认),不需要调用 API;
- `mode='wrap'`,通过取余约束越界的 indices,不需要调用 API;
- `mode='clip'`,将两端超出范围的索引约束到 [0, max_item-1]。

## TensorFlow

据我们的调研情况,TensorFlow 中没有自己实现 `take` API,而是直接调用 `numpy.take`:[tf.experimental.numpy.take](https://tensorflow.google.cn/api_docs/python/tf/experimental/numpy/take)。
Expand All @@ -233,11 +337,13 @@ def take(a, indices, axis=None, out=None, mode='raise'):

- `torch.take` 的 `index` 参数必须为 LongTensor 类型;`numpy.take` 直接对参数 `indices` 的元素取整再进行索引。

- 当不指定轴时,对于相同的索引矩阵,`Numpy.take` 的执行结果等于 `torch.take`。
- 在维度支持上,`numpy.take` 支持指定轴,`torch.take` 不支持。

- 当不指定轴时,对于相同的索引矩阵,`numpy.take` 的执行结果等于 `torch.take`。

- 在维度支持上,`Numpy.take` 支持指定轴,`torch.take` 不支持
- `numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错

- `Numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错
> 由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的 `paddle.take` 的功能与 `torch.take` 和 `numpy.take` 的默认形式保持一致,即,不增加 axis 参数指定索引轴;在 `torch.take` 的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性

# 五、方案设计

Expand All @@ -247,8 +353,9 @@ def take(a, indices, axis=None, out=None, mode='raise'):

```python
paddle.take(
input: Tensor,
x: Tensor,
index: Tensor,
mode: str='raise',
name: str=None)
```

Expand All @@ -268,21 +375,30 @@ paddle.take(

1. 通过 `Tensor.flatten()` 将输入 x 和 index 展开成 1D Tensor。

2. 通过 `Tensor.index_select(index)` 按照 index 中的索引提取对应元素。
2. 根据 mode 参数对索引进行越界处理:
- `mode='raise'`,若索引越界,通过最后调用的 `paddle.index_select` 抛出错误 (默认);
- `mode='wrap'`,通过取余约束越界的 indices;
- `mode='clip'`,通过 `paddle.clip` 将两端超出范围的索引约束到 [0, max_index-1]。

3. 通过 `Tensor.index_select(index)` 按照 index 中的索引提取对应元素。
- `numpy.take` 和 `torch.take` 支持负值索引;
- 然而 `index_select` 不支持,因此需要先将 index 的负值索引转为对应的正值索引。

3. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。
4. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。


# 六、测试和验收的考量

测试考虑的 case 如下:

- 参数 `index` 数据类型必须为 `paddle.int32` 和 `paddle.int64` 类型的 Tensor(与 `paddle.index_select` 一致)。

- `input` 的数据类型支持 `int32`,`int64`,`float32`,`float64`。
- `x` 的数据类型支持 `int32`,`int64`,`float32`,`float64`。

- `index` 索引越界时直接报错。
- `index` 索引越界的三种处理方式:
- `mode='raise'`,若索引越界,通过最后调用的 `paddle.index_select` 抛出错误 (默认);
- `mode='wrap'`,通过取余约束越界的 indices;
- `mode='clip'`,通过 `paddle.clip` 将两端超出范围的索引约束到 [0, max_index-1]。

- 在动态图、静态图下,以及 CPU、GPU 下,都能得到正确的结果。

Expand Down