Skip to content

Commit

Permalink
【Hackathon No.16】add RFC for take API (#217)
Browse files Browse the repository at this point in the history
* Create 2022-05-26-meeting-agenda.md

* add rfcs/APIs/20220714_api_design_for_take.md

* deleted rfcs/APIs/20220712_api_design_for_take.md

* fix rfcs/APIs/20220714_api_design_for_take.md

* Update 20220714_api_design_for_take.md

* Improve the survey

* Improve the survey

* update api_design_for_take.md

* Clarify the mode parameter
  • Loading branch information
S-HuaBomb committed Aug 29, 2022
1 parent e096869 commit 8e9a794
Showing 1 changed file with 124 additions and 8 deletions.
132 changes: 124 additions & 8 deletions rfcs/APIs/20220714_api_design_for_take.md
Expand Up @@ -217,6 +217,7 @@ def take(a, indices, axis=None, out=None, mode='raise'):
"""
return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode)
```

整体逻辑为:

- 当指定 `axis` 的时候,`numpy.take` 执行与 “fancy indexing” 相同的索引操作(使用数组索引数组);例如 `np.take(arr, indices, axis=3)` 等价于 `arr[:, :, :, indices, ...]`
Expand All @@ -225,6 +226,109 @@ def take(a, indices, axis=None, out=None, mode='raise'):

- 当提供参数 `out` 的时候,输出的数据将填充到 `out` 中。

numpy.take 的 mode 参数的越界处理的逻辑和核心代码[在这里](https://github.com/numpy/numpy/blob/54c52f13713f3d21795926ca4dbb27e16fada171/numpy/core/src/multiarray/item_selection.c#L34-L134)

```c++
// raise 模式,在索引超出最小(大)范围后直接报错
check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
PyThreadState * _save)
{
/* Check that index is valid, taking into account negative indices */
if (NPY_UNLIKELY((*index < -max_item) || (*index >= max_item))) {
NPY_END_THREADS;
/* Try to be as clear as possible about what went wrong. */
if (axis >= 0) {
PyErr_Format(PyExc_IndexError,
"index %"NPY_INTP_FMT" is out of bounds "
"for axis %d with size %"NPY_INTP_FMT,
*index, axis, max_item);
} else {
PyErr_Format(PyExc_IndexError,
"index %"NPY_INTP_FMT" is out of bounds "
"for size %"NPY_INTP_FMT, *index, max_item);
}
return -1;
}
/* adjust negative indices */
if (*index < 0) {
*index += max_item;
}
return 0;
}
```
```c++
npy_fasttake_impl(
char *dest, char *src, const npy_intp *indices, ...)
{
NPY_BEGIN_THREADS_DEF;
NPY_BEGIN_THREADS_DESCR(dtype);
switch (clipmode) {
case NPY_RAISE:
/**
raise 模式,通过上方 check_and_adjust_index 函数判断越界索引
*/
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (check_and_adjust_index(&tmp, max_item, axis,
_save) < 0) {
return -1;
}
...
}
break;
case NPY_WRAP:
/**
wrap 模式,将超出范围的索引取余。
由于 C 语言异号取余原则与 Python 不同,
因此使用循环加减来模拟 Python 取余结果。
*/
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (tmp < 0) {
while (tmp < 0) {
tmp += max_item;
}
}
else if (tmp >= max_item) {
while (tmp >= max_item) {
tmp -= max_item;
}
}
...
}
}
break;
case NPY_CLIP:
/**
clip 模式,将两端超出范围的索引约束到 [0, max_item-1]
*/
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (tmp < 0) {
tmp = 0;
}
else if (tmp >= max_item) {
tmp = max_item - 1;
}
...
}
}
break;
}
return 0;
}
```
综上,mode 索引越界处理方式总结为

- `mode='raise'`,判断索引范围并抛出错误 (默认),不需要调用 API;
- `mode='wrap'`,通过取余约束越界的 indices,不需要调用 API;
- `mode='clip'`,将两端超出范围的索引约束到 [0, max_item-1]

## TensorFlow

据我们的调研情况,TensorFlow 中没有自己实现 `take` API,而是直接调用 `numpy.take`[tf.experimental.numpy.take](https://tensorflow.google.cn/api_docs/python/tf/experimental/numpy/take)
Expand All @@ -233,11 +337,13 @@ def take(a, indices, axis=None, out=None, mode='raise'):

- `torch.take``index` 参数必须为 LongTensor 类型;`numpy.take` 直接对参数 `indices` 的元素取整再进行索引。

- 当不指定轴时,对于相同的索引矩阵,`Numpy.take` 的执行结果等于 `torch.take`
- 在维度支持上,`numpy.take` 支持指定轴,`torch.take` 不支持。

- 当不指定轴时,对于相同的索引矩阵,`numpy.take` 的执行结果等于 `torch.take`

- 在维度支持上,`Numpy.take` 支持指定轴,`torch.take` 不支持
- `numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错

- `Numpy.take` 支持通过 `mode` 参数指定索引越界的 3 种处理方式,默认直接报错;`torch.take` 在索引越界时直接报错
> 由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的 `paddle.take` 的功能与 `torch.take``numpy.take` 的默认形式保持一致,即,不增加 axis 参数指定索引轴;在 `torch.take` 的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性
# 五、方案设计

Expand All @@ -247,8 +353,9 @@ def take(a, indices, axis=None, out=None, mode='raise'):

```python
paddle.take(
input: Tensor,
x: Tensor,
index: Tensor,
mode: str='raise',
name: str=None)
```

Expand All @@ -268,21 +375,30 @@ paddle.take(

1. 通过 `Tensor.flatten()` 将输入 x 和 index 展开成 1D Tensor。

2. 通过 `Tensor.index_select(index)` 按照 index 中的索引提取对应元素。
2. 根据 mode 参数对索引进行越界处理:
- `mode='raise'`,若索引越界,通过最后调用的 `paddle.index_select` 抛出错误 (默认);
- `mode='wrap'`,通过取余约束越界的 indices;
- `mode='clip'`,通过 `paddle.clip` 将两端超出范围的索引约束到 [0, max_index-1]

3. 通过 `Tensor.index_select(index)` 按照 index 中的索引提取对应元素。
- `numpy.take``torch.take` 支持负值索引;
- 然而 `index_select` 不支持,因此需要先将 index 的负值索引转为对应的正值索引。

3. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。
4. 通过 `Tensor.reshape(index.shape)` 将输出的 Tensor 形状转成 index 的形状。


# 六、测试和验收的考量

测试考虑的 case 如下:

- 参数 `index` 数据类型必须为 `paddle.int32``paddle.int64` 类型的 Tensor(与 `paddle.index_select` 一致)。

- `input` 的数据类型支持 `int32``int64``float32``float64`
- `x` 的数据类型支持 `int32``int64``float32``float64`

- `index` 索引越界时直接报错。
- `index` 索引越界的三种处理方式:
- `mode='raise'`,若索引越界,通过最后调用的 `paddle.index_select` 抛出错误 (默认);
- `mode='wrap'`,通过取余约束越界的 indices;
- `mode='clip'`,通过 `paddle.clip` 将两端超出范围的索引约束到 [0, max_index-1]

- 在动态图、静态图下,以及 CPU、GPU 下,都能得到正确的结果。

Expand Down

0 comments on commit 8e9a794

Please sign in to comment.