[Hackathon Season 3 No.4] Add the cummax API to Paddle #45073

Closed
65 changes: 65 additions & 0 deletions paddle/fluid/operators/cum_op.cc
@@ -74,6 +74,62 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
}
};

class CummaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};

class CummaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of cummax operator");
// AddOutput("Out_values", "Output values of cummax operator");
// AddOutput("Out_indices", "Output indices of cummax operator");
AddOutput("Out", "Output values of cummax operator");
AddAttr<int>("axis",
"The dimension to operate along. -1 means the last "
"dimension [default: -1].")
.SetDefault(-1);
AddAttr<bool>("flatten",
"Whether to compute the cummax over the flattened array. "
"[default: false].")
.SetDefault(false);
AddAttr<bool>("exclusive",
"Whether to perform exclusive cummax. [default: false].")
.SetDefault(false);
AddAttr<bool>("reverse",
"If true, the cummax is performed in the reversed direction. "
"[default: false].")
.SetDefault(false);
AddComment(R"DOC(
The cumulative maximum of the elements along a given axis.
By default, the first element of the output is the same as the first element of
the input. If exclusive is true, the scan excludes the current element, so the
first element of the result is the identity of the max reduction.
)DOC");
}
};
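
For orientation, here is a minimal NumPy sketch of the semantics this op maker describes — a cumulative maximum along an axis, with optional flattening and reversal — under the assumption that the flags behave like cumsum's; the exclusive variant is left out because its seed value depends on the scan implementation.

    # Reference sketch of the intended cummax semantics (assumes the flags
    # mirror cumsum's handling of flatten/reverse).
    import numpy as np

    def cummax_reference(x, axis=-1, flatten=False, reverse=False):
        x = np.asarray(x)
        if flatten:
            x, axis = x.reshape(-1), 0
        if reverse:
            x = np.flip(x, axis=axis)
        out = np.maximum.accumulate(x, axis=axis)
        return np.flip(out, axis=axis) if reverse else out

    print(cummax_reference([[3, 1, 4], [1, 5, 2]], axis=0))
    # [[3 1 4]
    #  [3 5 4]]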

template <typename T>
class CummaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("cummax");
// grad_op->SetInput("X", this->OutputGrad("Out_values"));
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", PADDLE_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
PADDLE_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse",
!PADDLE_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
PADDLE_GET_CONST(bool, this->GetAttr("exclusive")));
}
};
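
The grad maker mirrors CumsumGradMaker: the generated backward op is simply another cummax over the incoming output gradient with the reverse flag negated. As a sketch of that wiring (a description of this diff's setup, not of the analytic cummax gradient), reusing the hypothetical cummax_reference helper above:

    # Backward computation as wired by CummaxGradOpMaker.
    def cummax_grad_as_wired(out_grad, axis=-1, flatten=False, reverse=False):
        # Note the flipped reverse flag, exactly as CumsumGradMaker does.
        return cummax_reference(out_grad, axis=axis, flatten=flatten,
                                reverse=not reverse)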

class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -149,6 +205,9 @@ using CPU = phi::CPUContext;
DECLARE_INFER_SHAPE_FUNCTOR(cumsum,
CumsumInferShapeFunctor,
PD_INFER_META(phi::CumInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(cummax,
CummaxInferShapeFunctor,
PD_INFER_META(phi::CummaxInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp,
LogcumsumexpInferShapeFunctor,
PD_INFER_META(phi::CumInferMeta));
@@ -158,6 +217,12 @@ REGISTER_OPERATOR(cumsum,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(cummax,
ops::CummaxOp,
ops::CummaxOpMaker,
ops::CummaxGradOpMaker<paddle::framework::OpDesc>,
ops::CummaxGradOpMaker<paddle::imperative::OpBase>,
CummaxInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp,
ops::CumOp,
ops::LogcumsumexpOpMaker,
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/api.yaml
@@ -43,6 +43,15 @@
data_type : x
backward : cross_grad

- api : cummax
args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse)
output : Tensor(out)
infer_meta :
func : CummaxInferMeta
kernel :
func : cummax
backward : cummax_grad
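
For context, a small sketch of how these yaml arguments surface through the Python wrapper added later in this PR (the wrapper pins exclusive and reverse to False); expected values follow directly from the definition of a running maximum:

    # Usage through the Python wrapper defined later in this PR
    # (python/paddle/tensor/math.py); illustration only.
    import paddle

    x = paddle.to_tensor([2.0, 1.0, 3.0])
    out = paddle.cummax(x, axis=-1)  # cummax(x, axis=-1, flatten=False,
                                     #        exclusive=False, reverse=False)
    print(out)  # expected: [2., 2., 3.]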

- api : diag
args : (Tensor x, int offset = 0, float padding_value = 0.0)
output : Tensor
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -39,6 +39,15 @@
func : cross_grad
data_type : out_grad

- backward_api : cummax_grad
forward : cummax(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
args : (Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse)
output : Tensor(x_grad)
invoke : cummax(out_grad, axis, flatten, exclusive, !reverse)

- backward_api : diag_grad
forward : diag (Tensor x, int offset, float padding_value) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset)
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -345,6 +345,33 @@ void CumInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void CummaxInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
// MetaTensor* out_values,
// MetaTensor* out_indices
MetaTensor* out) {
auto x_dims = x.dims();
if (flatten) {
// out_values->set_dims(phi::make_ddim({phi::product(x_dims)}));
// out_indices->set_dims(phi::make_ddim({phi::product(x_dims)}));
out->set_dims(phi::make_ddim({phi::product(x_dims)}));
} else {
// out_values->set_dims(x_dims);
// out_indices->set_dims(x_dims);
out->set_dims(x_dims);
}
// out_values->set_dtype(x.dtype());
// out_indices->set_dtype(DataType::INT64);
out->set_dtype(x.dtype());

// out_values->share_lod(x);
// out_indices->share_lod(x);
out->share_lod(x);
}
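
A quick sketch of the shape rule CummaxInferMeta implements: with flatten set, the output collapses to a 1-D tensor whose length is the product of the input dimensions; otherwise the shape is unchanged.

    # Shape rule mirrored in plain Python (illustration only).
    from functools import reduce
    import operator

    def cummax_out_shape(x_shape, flatten):
        if flatten:
            return [reduce(operator.mul, x_shape, 1)]
        return list(x_shape)

    print(cummax_out_shape([3, 4], True))   # [12]
    print(cummax_out_shape([3, 4], False))  # [3, 4]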

void CropTensorInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -84,6 +84,15 @@ void CumInferMeta(const MetaTensor& x,
bool reverse,
MetaTensor* out);

void CummaxInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
// MetaTensor* out_values,
// MetaTensor* out_indices);
MetaTensor* out);

void DecodeJpegInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* out);
32 changes: 30 additions & 2 deletions paddle/phi/kernels/cpu/cum_kernel.cc
@@ -146,6 +146,20 @@ void CumsumKernel(const Context& dev_ctx,
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}

template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Reducer = Eigen::internal::MaxReducer<T>;
auto reducer = Reducer();
ScanKernel<T, Context, Reducer>(
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}
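
The CPU path reuses the shared ScanKernel and only swaps in Eigen's MaxReducer, so for the default flags the output should behave like a running maximum; a small NumPy statement of that expectation (an assumption about ScanKernel's inclusive behaviour, not a test from this PR):

    # Expected inclusive max-scan result for the default flags.
    import numpy as np

    x = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
    print(np.maximum.accumulate(x))  # [3. 3. 4. 4. 5.]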

template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
@@ -267,5 +281,19 @@ PD_REGISTER_KERNEL(cumsum,
int,
int64_t) {}

PD_REGISTER_KERNEL(
logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
PD_REGISTER_KERNEL(cummax,
CPU,
ALL_LAYOUT,
phi::CummaxKernel,
float,
double,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(logcumsumexp,
CPU,
ALL_LAYOUT,
phi::LogcumsumexpKernel,
float,
double) {}
9 changes: 9 additions & 0 deletions paddle/phi/kernels/cum_kernel.h
@@ -27,6 +27,15 @@ void CumsumKernel(const Context& dev_ctx,
bool reverse,
DenseTensor* out);

template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out);

template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
40 changes: 38 additions & 2 deletions paddle/phi/kernels/gpu/cum_kernel.cu
@@ -139,6 +139,14 @@ struct LogAddExp {
}
};

struct Max {
template <typename T>
__host__ __device__ __forceinline__ T operator()(const T& a,
const T& b) const {
return std::max(a, b);
}
};
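
The Max struct is just the binary combine operator handed to the GPU scan; what makes it valid there is that max is associative, so per-block partial results can be merged in any grouping (the exclusive seed would presumably come from an Identity<T, Max> specialization, which is not shown in this hunk). A tiny illustration:

    # Associativity of max is what allows partial scan results to be combined
    # in any order.
    a, b, c = 3, 7, 5
    assert max(max(a, b), c) == max(a, max(b, c)) == 7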

template <typename T, typename op>
struct Identity;

@@ -364,6 +372,20 @@ void CumsumKernel(const Context& dev_ctx,
dev_ctx, x, axis, flatten, exclusive, reverse, op, out);
}

template <typename T, typename Context>
void CummaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Op = Max;
auto op = Op();
ScanKernel<T, Context, Op>(
dev_ctx, x, axis, flatten, exclusive, reverse, op, out);
}

template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -390,5 +412,19 @@ PD_REGISTER_KERNEL(cumsum,
int,
int64_t) {}

PD_REGISTER_KERNEL(
logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
PD_REGISTER_KERNEL(cummax,
GPU,
ALL_LAYOUT,
phi::CummaxKernel,
float,
double,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(logcumsumexp,
GPU,
ALL_LAYOUT,
phi::LogcumsumexpKernel,
float,
double) {}
68 changes: 68 additions & 0 deletions python/paddle/tensor/math.py
@@ -3079,6 +3079,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
else:
return _C_ops.cumsum(x, 'axis', axis, 'flatten', flatten)

# TODO(renziji): the static-graph branch below differs from the dygraph path; confirm whether it also needs changes.
check_type(x, 'x', (Variable), 'cumsum')
locals_var = locals().copy()
kwargs = dict()
@@ -3089,6 +3090,73 @@ def cumsum(x, axis=None, dtype=None, name=None):
return _cum_sum_(**kwargs)


def cummax(x, axis=None, dtype=None, name=None):
"""
The cumulative maximum of the elements along a given axis.

**Note**:
The first element of the result is the same as the first element of the input.

Args:
x (Tensor): The input tensor over which to compute the cumulative maximum.
axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cummax over the flattened array.
dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. The default value is None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor, the result of the cummax operator.

Examples:
.. code-block:: python

import paddle

data = paddle.to_tensor([[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8]])

y = paddle.cummax(data)
# [3 3 4 4 5 9 9 9 9 9 9 9]

y = paddle.cummax(data, axis=0)
# [[3 1 4 1]
#  [5 9 4 6]
#  [5 9 5 8]]

y = paddle.cummax(data, axis=-1)
# [[3 3 4 4]
#  [5 9 9 9]
#  [5 5 5 8]]

y = paddle.cummax(data, dtype='float64')
print(y.dtype)
# paddle.float64
"""
if axis is None:
flatten = True
else:
flatten = False
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)

if in_dygraph_mode():
if axis is None: axis = -1
return _C_ops.final_state_cummax(x, axis, flatten, False, False)
Contributor review comment:
Hi, the Paddle internals have been refactored and the final_state_ prefix has been removed. Please delete the final_state_ here.

Contributor review comment:
@OccupyMars2025
See #45306 for the refactoring PR, and #45306 (comment) for the corresponding update to the development docs.

if _in_legacy_dygraph():
if axis is None:
return _C_ops.cummax(x, 'flatten', flatten)
else:
return _C_ops.cummax(x, 'axis', axis, 'flatten', flatten)
Contributor review comment on lines +3146 to +3148:
Hi, the Paddle internals have been refactored; the legacy dygraph APIs now live under _legacy_C_ops instead of _C_ops. Please update this call as well.


check_type(x, 'x', (Variable), 'cummax')
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
if val is not None:
kwargs[name] = val
_cum_max_ = generate_layer_fn('cummax')
return _cum_max_(**kwargs)


def logcumsumexp(x, axis=None, dtype=None, name=None):
r"""
The logarithm of the cumulative summation of the exponentiation of the elements along a given axis.