【Hackathon No.19】Implement ASGD optimizer #42431
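
A summary of the update implemented by the CPU and GPU kernels in this diff (a reviewer aid derived from the kernels below, not text from the changed files; \theta is Param, g is Grad, \eta is LearningRate, \bar\theta is AvgParam, t is CurrentStep, t_0 is the t0 attribute):

\theta_{t+1} = \theta_t - \eta \, g_t
\bar\theta_{t+1} = \theta_{t+1} \qquad (t \le t_0)
\bar\theta_{t+1} = (1-\mu)\,\bar\theta_t + \mu\,\theta_{t+1}, \quad \mu = \frac{1}{t - t_0} \qquad (t > t_0)
t_{\mathrm{out}} = t + 1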

72 changes: 72 additions & 0 deletions paddle/fluid/operators/optimizers/asgd_op.cc
@@ -0,0 +1,72 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class AsgdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};

class AsgdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("AvgParam",
"(Tensor) Average of parameter");
AddInput("CurrentStep",
"(Tensor) Current step");
AddOutput("ParamOut",
"(Tensor, same with Param) "
"Output parameter, should share the same memory with Param");
AddOutput("AvgParamOut",
"(Tensor, same with AvgParam) Average of parameter");
AddOutput("CurrentStepOut",
"(Tensor) Increased step");

AddAttr<float>("t0",
"(float, default 1e6) point at which to start averaging")
.SetDefault(0.95f);
AddComment(R"DOC(
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(asgd, AsgdInferMetaFunctor,
PD_INFER_META(phi::AsgdInferMeta));
REGISTER_OPERATOR(
asgd, ops::AsgdOp, ops::AsgdOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
AsgdInferMetaFunctor);
46 changes: 46 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -1926,6 +1926,52 @@ void RnnInferMeta(const MetaTensor& x,
}
}

void AsgdInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& avg_param,
const MetaTensor& current_step,
float t0,
MetaTensor* param_out,
MetaTensor* avg_param_out,
MetaTensor* current_step_out) {
PADDLE_ENFORCE_NOT_NULL(param_out,
phi::errors::InvalidArgument(
"Output(ParamOut) of SGDOp should not be null."));

auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
phi::errors::InvalidArgument(
"Learning rate should have 1 element. But received "
"LearningRate dims [%s].",
lr_dims));

auto current_step_dims = current_step.dims();
PADDLE_ENFORCE_EQ(phi::product(current_step_dims),
1,
phi::errors::InvalidArgument(
"Current step should have 1 element. But received "
"CurrentStep dims [%s].",
current_step_dims));

auto param_dims = param.dims();
auto avg_param_dims = avg_param.dims();
PADDLE_ENFORCE_EQ(param_dims,
avg_param_dims,
phi::errors::InvalidArgument(
"Param and AvgParam should have the same dims. But received "
"[%s] and [%s]",
param_dims, avg_param_dims));

param_out->set_dims(param.dims());
param_out->set_dtype(param.dtype());
avg_param_out->set_dims(param.dims());
avg_param_out->set_dtype(param.dtype());
current_step_out->set_dims(current_step.dims());
current_step_out->set_dtype(current_step.dtype());
}

void SgdInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& grad,
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -122,6 +122,16 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void AsgdInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& avg_param,
const MetaTensor& current_step,
float t0,
MetaTensor* param_out,
MetaTensor* avg_param_out,
MetaTensor* current_step_out);

void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
33 changes: 33 additions & 0 deletions paddle/phi/kernels/asgd_kernel.h
@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void AsgdKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& avg_param,
const DenseTensor& current_step,
float t0,
DenseTensor* param_out,
DenseTensor* avg_param_out,
DenseTensor* current_step_out);

} // namespace phi
65 changes: 65 additions & 0 deletions paddle/phi/kernels/cpu/asgd_kernel.cc
@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/asgd_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

namespace phi {

template <typename T, typename Context>
void AsgdKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& avg_param,
const DenseTensor& current_step,
float t0,
DenseTensor* param_out,
DenseTensor* avg_param_out,
DenseTensor* current_step_out) {
dev_ctx.template Alloc<T>(param_out);
dev_ctx.template Alloc<T>(avg_param_out);
dev_ctx.template Alloc<T>(current_step_out);

auto eigen_param = EigenVector<T>::Flatten(param);
auto eigen_grad = EigenVector<T>::Flatten(grad);
auto eigen_avg_param = EigenVector<T>::Flatten(avg_param);
auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
auto eigen_avg_param_out = EigenVector<T>::Flatten(*avg_param_out);
auto& place = *dev_ctx.eigen_device();

auto lr = learning_rate.data<T>()[0];
eigen_param_out.device(place) = eigen_param - lr * eigen_grad;

T current_step_data = current_step.data<T>()[0];

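// Averaging only starts once the step count exceeds t0; until then the
// average simply tracks the freshly updated parameter.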
if (current_step_data <= t0) {
eigen_avg_param_out.device(place) = eigen_param_out;
} else {
const auto mu1 = 1 / (current_step_data - t0);
const auto mu2 = 1 - mu1;
eigen_avg_param_out.device(place) =
mu2 * eigen_avg_param + mu1 * eigen_param_out;
}
current_step_out->data<T>()[0] = current_step_data + 1;
}

} // namespace phi

PD_REGISTER_KERNEL(asgd, CPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {}
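
For reference, here is a standalone sketch of the same arithmetic on plain std::vector<float>, written as a possible oracle for a unit test. It is not part of this PR; the name AsgdReference, its signature, and the sample values in main are illustrative only, but the update mirrors the CPU kernel above.

// Standalone reference sketch of the asgd update (not part of this PR).
#include <cstddef>
#include <iostream>
#include <vector>

void AsgdReference(const std::vector<float>& param,
                   const std::vector<float>& grad,
                   float lr,
                   const std::vector<float>& avg_param,
                   float current_step,
                   float t0,
                   std::vector<float>* param_out,
                   std::vector<float>* avg_param_out,
                   float* current_step_out) {
  const std::size_t n = param.size();
  param_out->resize(n);
  avg_param_out->resize(n);
  // Plain SGD step on the parameter.
  for (std::size_t i = 0; i < n; ++i) {
    (*param_out)[i] = param[i] - lr * grad[i];
  }
  if (current_step <= t0) {
    // Averaging has not started yet: the average tracks the updated parameter.
    *avg_param_out = *param_out;
  } else {
    // Running mean with coefficient 1 / (current_step - t0).
    const float mu1 = 1.0f / (current_step - t0);
    const float mu2 = 1.0f - mu1;
    for (std::size_t i = 0; i < n; ++i) {
      (*avg_param_out)[i] = mu2 * avg_param[i] + mu1 * (*param_out)[i];
    }
  }
  *current_step_out = current_step + 1.0f;
}

int main() {
  std::vector<float> param{1.0f, 2.0f}, grad{0.1f, 0.2f}, avg{0.5f, 0.5f};
  std::vector<float> param_out, avg_out;
  float step_out = 0.0f;
  AsgdReference(param, grad, /*lr=*/0.01f, avg, /*current_step=*/3.0f,
                /*t0=*/0.95f, &param_out, &avg_out, &step_out);
  std::cout << param_out[0] << " " << avg_out[0] << " " << step_out << std::endl;
  return 0;
}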
83 changes: 83 additions & 0 deletions paddle/phi/kernels/gpu/asgd_kernel.cu
@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/asgd_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
__global__ void ASGDKernel(const T* param,
const T* grad,
const T* learning_rate,
const T* avg_param,
const T* current_step,
float t0,
size_t num,
T* param_out,
T* avg_param_out) {
T lr = learning_rate[0];
CUDA_KERNEL_LOOP(i, num) { param_out[i] = param[i] - lr * grad[i]; }
T current_step_data = current_step[0];
if (current_step_data <= t0) {
// Match the CPU kernel: before averaging starts, the average tracks the
// freshly updated parameter (param_out), and each thread only touches its
// own indices instead of every thread copying the whole buffer.
CUDA_KERNEL_LOOP(i, num) { avg_param_out[i] = param_out[i]; }
} else {
const T mu1 = 1 / (current_step_data - t0);
const T mu2 = 1 - mu1;
CUDA_KERNEL_LOOP(i, num) {
avg_param_out[i] = mu2 * avg_param[i] + mu1 * param_out[i];
}
}
}

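// One-thread kernel that increments the step counter directly on the device,
// so the whole update runs on the GPU stream without a host round trip.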
template <typename T>
__global__ void IncreaseStep(const T* step, T* step_out) {
*step_out = *step + 1;
}

template <typename T, typename Context>
void AsgdKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& avg_param,
const DenseTensor& current_step,
float t0,
DenseTensor* param_out,
DenseTensor* avg_param_out,
DenseTensor* current_step_out) {
int block = 512;
int grid = (param.numel() + block - 1) / block;

ASGDKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
param.data<T>(),
grad.data<T>(),
learning_rate.data<T>(),
avg_param.data<T>(),
current_step.data<T>(),
t0,
param.numel(),
param_out->mutable_data<T>(dev_ctx.GetPlace()),
avg_param_out->mutable_data<T>(dev_ctx.GetPlace()));

IncreaseStep<T><<<1, 1, 0, dev_ctx.stream()>>>(
current_step.data<T>(),
current_step_out->mutable_data<T>(dev_ctx.GetPlace()));
}

} // namespace phi

PD_REGISTER_KERNEL(asgd, GPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {}
9 changes: 9 additions & 0 deletions paddle/utils/variant.h
@@ -13,6 +13,11 @@

#pragma once

#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
#endif
Contributor Author comment: The changes to this file were already submitted in a separate PR (#42265) and can be ignored when reviewing this PR.


/*
variant synopsis

@@ -2828,3 +2833,7 @@ struct hash<paddle::monostate> {
};

} // namespace std

#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9
#pragma GCC diagnostic pop
#endif
Binary file added python/paddle/fluid/tests/unittests/model.pdparams
Binary file added python/paddle/fluid/tests/unittests/opt.pdopt