Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[triu_indices] add triu_indices_op #45168

Merged
merged 29 commits into from Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 86 additions & 0 deletions paddle/fluid/operators/triu_indices_op.cc
@@ -0,0 +1,86 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"

namespace paddle {
namespace operators {

class TriuIndicesOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};

class TriuIndicesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("out",
"Tensor, the output tensor, with the shape (2,x), x bounded by "
"[0,rows*cols])");
AddAttr<int>("rows",
"int number, the input of triu_indices op"
"which describes the number of row of the matrix")
.SetDefault(0);
AddAttr<int>("cols",
"int number, the input of triu_indices op"
"which describes the number of col of the matrix")
.SetDefault(0);
AddAttr<int>(
"offset",
"int number, the input of triu_indices op bounded by [1-rows,cols-1"
"which describes the dignalline index of the upper triangular part of "
"the matrix")
.SetDefault(0);
AddAttr<int>("dtype", "data type ,the input of triu_indices op")
.SetDefault(framework::proto::VarType::INT64);

AddComment(R"DOC(
TriuIndices Operator.
The triu_indices operator returns the indices of the upper triangular part of the matrix
whose rows and cols is known. It is a 2-by-x tensor, where the first row contains row coordinates
of all indices and the second row contains column coordinates. Indices are ordered based on
rows and then columns. The upper triangular part of the matrix is defined as the elements on
and below the diagonal.
The argument offset controls which diagonal to consider, default value is 0.
A positive value includes just as fewer diagonals above the main diagonal,
and similarly a negative value excludes just as fewer diagonals below the main diagonal
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(triu_indices,
TriuIndicesInferShapeFunctor,
PD_INFER_META(phi::TriuIndicesInferMeta));

REGISTER_OPERATOR(
triu_indices,
ops::TriuIndicesOp,
ops::TriuIndicesOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
TriuIndicesInferShapeFunctor);
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Expand Up @@ -2747,6 +2747,18 @@
data_type : x
backward : trilinear_interp_grad

- api : triu_indices
args : (int rows, int cols, int offset, DataType dtype, Place place={})
output : Tensor(out)
infer_meta :
func : TriuIndicesInferMeta
param : [rows, cols, offset, dtype]
kernel :
func : triu_indices
param : [rows, cols, offset, dtype]
data_type : dtype
backend : place

# python API: paddle.nn.initializer.TruncatedNormal
- api : truncated_gaussian_random
args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
Expand Down
30 changes: 30 additions & 0 deletions paddle/phi/infermeta/nullary.cc
Expand Up @@ -152,4 +152,34 @@ void TrilIndicesInferMeta(
out->set_dims(out_dims);
out->set_dtype(dtype);
}

void TriuIndicesInferMeta(
int rows, int cols, int offset, DataType dtype, MetaTensor* out) {
// number of elements in the first row of the tril,bounded by [0, cols]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

开头说明下计算逻辑:是通过总元素数量-下三角元素数量求得上三角元素数量的,所以offset要-1。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

开头说明下计算逻辑:是通过总元素数量-下三角元素数量求得上三角元素数量的,所以offset要-1。

已添加

// use total item number minus bottom rectangle item number to get
// the above rectangle item number
// triu_size = rows * cols - tril_size
// so the `offset` need to be set as `offset-1` in order to include
// the item on the diagonal line
offset = offset - 1;
auto n_first_row =
offset > 0 ? std::min<int64_t>(cols, 1 + offset) : rows + offset > 0;
// number of elements in the last row of the tril, bounded by [0, cols]
auto n_last_row =
std::max<int64_t>(0, std::min<int64_t>(cols, rows + offset));
// number of rows, bounded by [0, rows]
auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(rows, rows + offset));
auto n_row_trapezoid = (n_last_row - n_first_row + 1);
// calculate # of elements in the top trapezoid
auto tril_size = (n_first_row + n_last_row) * n_row_trapezoid >> 1;
// calculate # of elements in the bottom rectangle if there is any
auto diff_row = n_row_all - n_row_trapezoid;
if (diff_row > 0) {
tril_size += diff_row * cols;
}
std::vector<int64_t> tmp = {2, rows * cols - tril_size};
auto out_dims = phi::make_ddim(tmp);
out->set_dims(out_dims);
out->set_dtype(dtype);
}
} // namespace phi
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/nullary.h
Expand Up @@ -74,4 +74,7 @@ void UniformRandomInferMeta(const IntArray& shape,

void TrilIndicesInferMeta(
int rows, int cols, int offset, DataType dtype, MetaTensor* out);

void TriuIndicesInferMeta(
int rows, int cols, int offset, DataType dtype, MetaTensor* out);
} // namespace phi
51 changes: 51 additions & 0 deletions paddle/phi/kernels/cpu/triu_indices_kernel.cc
@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/triu_indices_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
template <typename T, typename Context>
void TriuIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
const auto& out_dims = out->dims();
int64_t triu_size = out_dims[1];
int64_t i = 0;
T c = std::max<int64_t>(0, offset), r = 0;
while (i < triu_size) {
out_data[i] = r;
out_data[triu_size + i++] = c;

// move to the next column and check if (r, c) is still in bound
c += 1;
if (c >= cols) {
r += 1;
// not typing std::max with scalar_t as it could be an unsigned type
// NOTE: not necessary to check if c is less than col or overflows here,
// because i and triu_size act as a guard.
c = std::max<int64_t>(0, r + offset);
}
}
}
} // namespace phi

PD_REGISTER_KERNEL(
triu_indices, CPU, ALL_LAYOUT, phi::TriuIndicesKernel, int, int64_t) {}
133 changes: 133 additions & 0 deletions paddle/phi/kernels/gpu/triu_indices_kernel.cu
@@ -0,0 +1,133 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/triu_indices_kernel.h"

#include <algorithm>
#include <tuple>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
__device__ inline int resolve_root_int(int b, int cX4, int x, int32_t sign) {
int bXb_cX4 = b * b - cX4;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to multiply two int32 numbers, it's better to use int64_t or long long int to avoid overflow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

double sr = ::sqrt(static_cast<double>(bXb_cX4));
T res = ::__double2ll_rd((-b + sign * sr) / 2);
if (bXb_cX4 != static_cast<int>(sr * sr)) {
int llsr = ::__double2ll_rd(sr);
int diff = ::__double2ll_ru(
::sqrt(::fabs(static_cast<double>(bXb_cX4 - llsr * llsr))));
auto l = res > diff ? res - diff : 0;
auto r = res + diff + 1;
x <<= 1;
while (l + 1 < r) {
auto m = (l + r) >> 1;
if (sign * (b + m) * m > x) {
r = m;
} else {
l = m;
}
}
res = l;
}
return res;
}

template <typename T>
__device__ inline void get_coordinate_in_triu_trapezoid(int f,
int x,
T* row,
T* col) {
f <<= 1; // all statements use 2f, so only calculate it once here.
auto b = -1 - f;
auto cX4 = x << 3; // 4 * c = 4 * (2x) = 8x;
*row = resolve_root_int<T>(b, cX4, x, -1);
*col = (x - (((f - *row + 1) * *row) >> 1)) + *row;
}

template <typename T>
__global__ void triu_indices_kernel(T* out_data,
int col_offset,
int m_first_row,
int col,
int rectangle_size,
int triu_size) {
int linear_index = blockIdx.x * blockDim.x + threadIdx.x;

if (linear_index < triu_size) {
T r, c;
if (linear_index < rectangle_size) {
// the coordinate is within the top rectangle
r = linear_index / col;
c = linear_index % col;
} else {
// the coordinate falls in the bottom trapezoid
get_coordinate_in_triu_trapezoid<T>(
m_first_row, linear_index - rectangle_size, &r, &c);
r += rectangle_size / col;
}

c += col_offset;
out_data[linear_index] = r;
out_data[linear_index + triu_size] = c;
}
}

template <typename T, typename Context>
void TriuIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
auto out_dims = out->dims();
int triu_size = out_dims[1];
// auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt,
// device_opt, pin_memory_opt);

if (triu_size > 0) {
// # of triu elements in the first row
auto m_first_row = offset > 0 ? std::max<int>(cols - offset, 0)
: // upper bounded by col
cols;

// size of the top rectangle
int rectangle_size = 0;
if (offset < 0) {
rectangle_size = std::min<int>(rows, -offset) * cols;
}

// using gpu_launch_config to get grid_size and block_size
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, triu_size);

triu_indices_kernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_data,
std::max<int>(0, offset),
m_first_row,
cols,
rectangle_size,
triu_size);
}
}
} // namespace phi

PD_REGISTER_KERNEL(
triu_indices, GPU, ALL_LAYOUT, phi::TriuIndicesKernel, int, int64_t) {}
29 changes: 29 additions & 0 deletions paddle/phi/kernels/triu_indices_kernel.h
@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void TriuIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out);

} // namespace phi
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Expand Up @@ -110,6 +110,7 @@
from .tensor.creation import complex # noqa: F401
from .tensor.creation import clone # noqa: F401
from .tensor.creation import tril_indices #noqa: F401
from .tensor.creation import triu_indices #noqa: F401
from .tensor.linalg import matmul # noqa: F401
from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
Expand Down Expand Up @@ -654,4 +655,5 @@
'heaviside',
'tril_indices',
'sgn',
'triu_indices',
]