Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.19】Implement ASGD optimizer #42431

72 changes: 72 additions & 0 deletions paddle/fluid/operators/optimizers/asgd_op.cc
@@ -0,0 +1,72 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class AsgdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};

class AsgdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("AvgParam",
"(Tensor) Average of parameter");
AddInput("CurrentStep",
"(Tensor) Current step");
AddOutput("ParamOut",
"(Tensor, same with Param) "
"Output parameter, should share the same memory with Param");
AddOutput("AvgParamOut",
"(Tensor, same with AvgParam) Average of parameter");
AddOutput("CurrentStepOut",
"(Tensor) Increased step");

AddAttr<float>("t0",
"(float, default 1e6) point at which to start averaging")
.SetDefault(0.95f);
AddComment(R"DOC(
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(asgd, AsgdInferMetaFunctor,
PD_INFER_META(phi::AsgdInferMeta));
REGISTER_OPERATOR(
asgd, ops::AsgdOp, ops::AsgdOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
AsgdInferMetaFunctor);
46 changes: 46 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Expand Up @@ -1926,6 +1926,52 @@ void RnnInferMeta(const MetaTensor& x,
}
}

void AsgdInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& avg_param,
const MetaTensor& current_step,
float t0,
MetaTensor* param_out,
MetaTensor* avg_param_out,
MetaTensor* current_step_out) {
PADDLE_ENFORCE_NOT_NULL(param_out,
phi::errors::InvalidArgument(
"Output(ParamOut) of SGDOp should not be null."));

auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_EQ(phi::product(lr_dims),
1,
phi::errors::InvalidArgument(
"Learning rate should have 1 element. But received "
"LearningRate dims [%s]",
phi::product(lr_dims)));

auto current_step_dims = current_step.dims();
PADDLE_ENFORCE_EQ(phi::product(current_step_dims),
1,
phi::errors::InvalidArgument(
"Current step should have 1 element. But received "
"dims [%s]",
phi::product(current_step_dims)));

auto param_dims = param.dims();
auto avg_param_dims = avg_param.dims();
PADDLE_ENFORCE_EQ(param_dims,
avg_param_dims,
phi::errors::InvalidArgument(
"Param and AvgParam should have the same dims. But received "
"[%s] and [%s]",
param_dims, avg_param_dims));

param_out->set_dims(param.dims());
param_out->set_dtype(param.dtype());
avg_param_out->set_dims(param.dims());
avg_param_out->set_dtype(param.dtype());
current_step_out->set_dims(current_step.dims());
current_step_out->set_dtype(current_step.dtype());
}

void SgdInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& grad,
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/multiary.h
Expand Up @@ -122,6 +122,16 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void AsgdInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& avg_param,
const MetaTensor& current_step,
float t0,
MetaTensor* param_out,
MetaTensor* avg_param_out,
MetaTensor* current_step_out);

void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
Expand Down
33 changes: 33 additions & 0 deletions paddle/phi/kernels/asgd_kernel.h
@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void AsgdKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
const DenseTensor& avg_param,
const DenseTensor& current_step,
float t0,
DenseTensor* param_out,
DenseTensor* avg_param_out,
DenseTensor* current_step_out);

} // namespace phi
65 changes: 65 additions & 0 deletions paddle/phi/kernels/cpu/asgd_kernel.cc
@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/asgd_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

namespace phi {

template <typename T, typename Context>
void AsgdKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& avg_param,
const DenseTensor& current_step,
float t0,
DenseTensor* param_out,
DenseTensor* avg_param_out,
DenseTensor* current_step_out) {
dev_ctx.template Alloc<T>(param_out);
dev_ctx.template Alloc<T>(avg_param_out);
dev_ctx.template Alloc<T>(current_step_out);

auto eigen_param = EigenVector<T>::Flatten(param);
auto eigen_grad = EigenVector<T>::Flatten(grad);
auto eigen_avg_param = EigenVector<T>::Flatten(avg_param);
auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
auto eigen_avg_param_out = EigenVector<T>::Flatten(*avg_param_out);
auto& place = *dev_ctx.eigen_device();

auto lr = learning_rate.data<T>()[0];
eigen_param_out.device(place) = eigen_param - lr * eigen_grad;

T current_step_data = current_step.data<T>()[0];

if (current_step_data <= t0) {
eigen_avg_param_out.device(place) = eigen_param_out;
} else {
const auto mu1 = 1 / (current_step_data - t0);
const auto mu2 = 1 - mu1;
eigen_avg_param_out.device(place) =
mu2 * eigen_avg_param + mu1 * eigen_param_out;
}
*current_step_out->mutable_data<T>(dev_ctx.GetPlace()) =
current_step_data + 1;
}

} // namespace phi

PD_REGISTER_KERNEL(asgd, CPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {}
83 changes: 83 additions & 0 deletions paddle/phi/kernels/gpu/asgd_kernel.cu
@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/asgd_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
__global__ void ASGDKernel(const T* param,
const T* grad,
const T* learning_rate,
const T* avg_param,
const T* current_step,
float t0,
size_t num,
T* param_out,
T* avg_param_out) {
T lr = learning_rate[0];
CUDA_KERNEL_LOOP(i, num) { param_out[i] = param[i] - lr * grad[i]; }
T current_step_data = current_step[0];
if (current_step_data <= t0) {
memcpy(avg_param_out, param, num * sizeof(T));
} else {
const auto mu1 = 1 / (current_step_data - t0);
const auto mu2 = 1 - mu1;
CUDA_KERNEL_LOOP(i, num) {
avg_param_out[i] = mu2 * avg_param[i] + mu1 * param_out[i];
}
}
}

template <typename T>
__global__ void IncreaseStep(const T* step, T* step_out) {
*step_out = *step + 1;
}

template <typename T, typename Context>
void AsgdKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
const DenseTensor& avg_param,
const DenseTensor& current_step,
float t0,
DenseTensor* param_out,
DenseTensor* avg_param_out,
DenseTensor* current_step_out) {
int block = 512;
int grid = (param.numel() + block - 1) / block;

ASGDKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
param.data<T>(),
grad.data<T>(),
learning_rate.data<T>(),
avg_param.data<T>(),
current_step.data<T>(),
t0,
param.numel(),
param_out->mutable_data<T>(dev_ctx.GetPlace()),
avg_param_out->mutable_data<T>(dev_ctx.GetPlace()));

IncreaseStep<T><<<1, 1, 0, dev_ctx.stream()>>>(
current_step.data<T>(),
current_step_out->mutable_data<T>(dev_ctx.GetPlace()));
}

} // namespace phi

PD_REGISTER_KERNEL(asgd, GPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {}
Binary file added python/paddle/fluid/tests/unittests/model.pdparams
Binary file not shown.
Binary file added python/paddle/fluid/tests/unittests/opt.pdopt
Binary file not shown.
55 changes: 55 additions & 0 deletions python/paddle/fluid/tests/unittests/test_optimizer.py
Expand Up @@ -69,6 +69,61 @@ def check_sgd_optimizer(optimizer_attr):
self.assertEqual(len(opts), 1)
self.assertEqual([op.type for op in opts], ["sgd"])

def test_asgd_optimizer(self):
w_shape = [3, 4]
class MyLayer(paddle.nn.Layer):
def __init__(self):
super(MyLayer, self).__init__()
self._w = self.create_parameter(w_shape, default_initializer=paddle.fluid.initializer.ConstantInitializer())

def forward(self, x):
return x * self._w

with paddle.fluid.dygraph.guard():
np_neg_ones = np.ones(w_shape) * -1

model = MyLayer()
x = paddle.ones([1, 3, 4])
asgd = paddle.optimizer.ASGD(learning_rate=1., parameters=model.parameters(), t0=1)

loss = model(x)
print(f'1: w grad before bw: {model._w.grad}')
loss.backward()
print(f'1: w grad: {model._w.grad}')
asgd.step()
assert np.allclose(model._w.numpy(), np_neg_ones)
assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones)
asgd.clear_grad()

loss = model(x)
print(f'2: w grad before bw: {model._w.grad}')
loss.backward()
print(f'2: w grad: {model._w.grad}')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

【需要帮助】

这里有很奇怪的问题。这个测试在本地是可以通过的,但在 CI 环境里,第二次 loss.backward() 前后 model._w.grad 均为 None(通过这些 print 语句看出的)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI里提示有CONFLICT,先同步一下最新的代码再试试吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

conflict 是来自于我自己的这个 PR 被合并到 develop 了:#42265

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已同步

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个问题很奇怪。本地是完全正常的,但 ci 环境里 loss.backward() 后 grad 竟然是 None。

或者有什么方法强行设置 model._w 的 grad 也可以,我尝试过 model._w.grad = ... 但会报错,不知道有没有其他方式

Copy link
Contributor

@chenwhql chenwhql May 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiancaishaonvjituizi 您可太热心了,这个咱们之前讨论过了,感谢建议,我们后面会看下有没有更好的方式去替代这部分自动生成声明的实现,结合编译效率来看的话,每个文件中都展开一编注册的代码确实也不太高效,这确实不是一种好的方式。

不过其实这个设计在跟您在这儿讨论的这批人加入paddle之前就是这样的了,#14413
这个我理解在当时也起到了降低开发成本的作用,也没有您说的那么严重,但当然这并不是一种好的方式我也认可。我们后来人在此基础上的开发和重构都要在时效和目标之间做平衡和取舍,在目前的体系下,这样的方式既能够减少大家的开发成本,也能满足快速迭代上线的目标。如果我们先对整个paddle编译的体系进行重构优化的话,当然也能避免这种写法,但在算子库重构启动的时候这个不是最高优目标,这不只是一个技术的问题,涉及到团队规划和管理的多方面因素。

总之感谢您的建议,无论是编译方式和编译效率优化,还是Python API的分支写法优化,我们最近都在高优推进了,希望理解,也希望再给我们一些耐心,paddle会变得更好,也希望您能给我们更多的输入,大家一起将paddle建设得更好

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用C++的宏定义的时候,确实要很小心。

@jzhang533 这和“使用宏的时候要小心”没有关系,可能是你还没有了解为什么这里换一下位置就会报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不过其实这个设计在跟您在这儿讨论的这批人加入paddle之前就是这样的了,#14413
这个我理解在当时也起到了降低开发成本的作用,也没有您说的那么严重,但当然这并不是一种好的方式我也认可。我们后来人在此基础上的开发和重构都要在时效和目标之间做平衡和取舍,在目前的体系下,这样的方式既能够减少大家的开发成本,也能满足快速迭代上线的目标。如果我们先对整个paddle编译的体系进行重构优化的话,当然也能避免这种写法,但在算子库重构启动的时候这个不是最高优目标,这不只是一个技术的问题,涉及到团队规划和管理的多方面因素。

@chenwhql 通俗的讲就是只想着赶紧做业务,忽视了对底层代码的维护和迭代,所以屎山代码越堆越大,直到某一天爆雷了,才意识到要花十倍的时间去弥补 😂 我可以理解,同时 “屎山代码越堆越大不是好事”也是不需要再专门说明原因的老生常谈。还是希望 paddle 团队多多自我审视,软件工程就是在和人“再苟一苟”的惰性斗争

也没有您说的那么严重

这其实是一个常见的误区,原因是

  1. 屎山代码的危害本身就是温水煮青蛙
  2. 熟悉了内情的内部开发者是难以切身体会到其中的问题的,因为他在潜意识里已经熟知了那些坑,所以一举一动都会不自觉地落在“舒适区”里

其实这样做有两个方面的问题:

  1. 现实问题:新加入 paddle 开发的开发者在这里非常容易出错(我自己就是例子),特别是领域专家级别的开发者。小白开发者反而会少出问题,因为小白开发者本身就倾向于复制粘贴,不在乎代码冗余度和可维护性。也就带来这样一种反最佳实践的现象:好的写法无法通过编译,小白写法反而可以。这也是为什么我说这么做 “反映了 ‘鼓励流水线式批量产出代码,不在意可维护性’ 的潜意识”。
  2. 代码品味问题:C++ 语言的复杂和难以解析是一个众所周知的事情,在这种情况下 paddle “头铁” 地用字符串匹配来解析 C++ 代码,在熟悉 C++ 的人的眼里,是一件大跌眼镜的事情,挺损害 paddle 的形象的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhiqiu

关于5,如上所述,目前的构建体系中有大量的静态库,新增c++代码需要手动修改很多依赖,导致cmake和make速度较慢,后面的一个优化方向是只编译少数几个动态库,这样可以把依赖层级扁平化。

请问这一部分有后续的更新/详细计划吗?

Copy link
Contributor

@zhiqiu zhiqiu Jul 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhiqiu

关于5,如上所述,目前的构建体系中有大量的静态库,新增c++代码需要手动修改很多依赖,导致cmake和make速度较慢,后面的一个优化方向是只编译少数几个动态库,这样可以把依赖层级扁平化。

请问这一部分有后续的更新/详细计划吗?

目前已经在开展部分工作,如#43247 ,将所有phi算子kernel编译为两个静态库。后续会逐步推广。

asgd.step()
assert np.allclose(model._w.numpy(), np_neg_ones * 2)
assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2)
asgd.clear_grad()

loss = model(x)
loss.backward()
asgd.step()
assert np.allclose(model._w.numpy(), np_neg_ones * 3)
assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3)
asgd.clear_grad()

loss = model(x)
loss.backward()
asgd.step()
assert np.allclose(model._w.numpy(), np_neg_ones * 4)
assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3.5)
asgd.clear_grad()

loss = model(x)
loss.backward()
asgd.step()
assert np.allclose(model._w.numpy(), np_neg_ones * 5)
assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 4)


class TestOptimizerBackwardApplygrad(unittest.TestCase):
def test_sgd_optimizer(self):
Expand Down