Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】9、为 Paddle 新增 logspace API #41261

Merged
merged 21 commits into from Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
80 changes: 80 additions & 0 deletions paddle/fluid/operators/logspace_op.cc
@@ -0,0 +1,80 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

@chenwhql chenwhql Apr 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

license的格式有点问题,缺失空行和缩进,可以参考其他文件修改一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C++基础头文件和项目自身的头文件之间空一行,方便区分。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

class LogspaceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数移除就好,默认就是这样,不用写多余的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};

class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Start",
"Exponent of first entry in the sequence. It is a tensor of "
"shape [1], should be of type int32, int64, float32 or float64.");
AddInput("Stop",
"Exponent of last entry in the sequence. It is a tensor of "
"shape [1], should be of type int32, int64, float32 or float64.");
AddInput("Num",
"Number of entry in the sequence. It is a tensor of shape [1], "
"should be of type int32.");
AddInput("Base",
"Base of the logarithm function. It is a tensor of shape [1], "
"should be of type int32, int64, float32 or float64.");
AddAttr<int>("dtype", "The output data type.");
AddOutput("Out", "A sequence of numbers.");
AddComment(R"DOC(
Return fixed number of logarithmical-evenly spaced values within a given
interval. First entry is exponential of Start with base Base, and last
entry is exponential of Stop with base Base. In the case when Num is 1,
only exponential of Start with base Base is returned. If dtype is int32
or int64, the decimal part of values will be truncated.
Like logspace function of numpy.
)DOC");
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(logspace, LogspaceInferShapeFunctor,
PD_INFER_META(phi::LogspaceInferMeta));
REGISTER_OPERATOR(
logspace, ops::LogspaceOp, ops::LogspaceOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
LogspaceInferShapeFunctor);
37 changes: 37 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Expand Up @@ -1488,6 +1488,43 @@ void InterpolateInferMeta(
}
}

void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
const MetaTensor& base,
MetaTensor* out) {
auto s_dims = start.dims();
PADDLE_ENFORCE_EQ(
(s_dims.size() == 1) && (s_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
"but received input shape is [%s].",
s_dims));
auto e_dims = stop.dims();
PADDLE_ENFORCE_EQ(
(e_dims.size() == 1) && (e_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
"but received input shape is [%s].",
e_dims));
auto num_dims = number.dims();
PADDLE_ENFORCE_EQ(
(num_dims.size() == 1) && (num_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
"but received input shape is [%s].",
num_dims));
auto b_dims = base.dims();
PADDLE_ENFORCE_EQ(
(b_dims.size() == 1) && (b_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Base) must be [1],"
"but received input shape is [%s].",
b_dims));
out->set_dims(phi::make_ddim({-1}));
out->set_dtype(start.dtype());
}

void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/multiary.h
Expand Up @@ -227,6 +227,12 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());

void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
const MetaTensor& base,
MetaTensor* out);

void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);

Expand Down
77 changes: 77 additions & 0 deletions paddle/phi/kernels/cpu/logspace_kernel.cc
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/logspace_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

头文件的include顺序需要调整下,可以参考说明:https://zh-google-styleguide.readthedocs.io/en/latest/google-cpp-styleguide/headers/#include

image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成


namespace phi {

template <typename T, typename Context>
void LogspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
const DenseTensor& base,
DataType dtype,
DenseTensor* out) {
int32_t num = number.data<int32_t>()[0];
auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
auto base_t = phi::funcs::TransDataType(ctx, base, dtype);

T start_data = start_t.template data<T>()[0];
T stop_data = stop_t.template data<T>()[0];
T base_data = base_t.template data<T>()[0];
PADDLE_ENFORCE_GT(
num,
0,
phi::errors::InvalidArgument("The num of logspace op should be larger "
"than 0, but received num is %d",
num));

out->Resize(phi::make_ddim({num}));
T* out_data = ctx.template Alloc<T>(out);

if (num > 1) {
// step should be of double type for all types
double step = (static_cast<double>(stop_data - start_data)) / (num - 1);
int half_num = num / 2;
for (int i = 0; i < num; ++i) {
if (i < half_num) {
out_data[i] =
static_cast<T>(std::pow(base_data, start_data + step * i));
} else {
out_data[i] = static_cast<T>(
std::pow(base_data, stop_data - step * (num - i - 1)));
}
}
} else {
out_data[0] = static_cast<T>(std::pow(base_data, start_data));
}
}

} // namespace phi

PD_REGISTER_KERNEL(logspace,
CPU,
ALL_LAYOUT,
phi::LogspaceKernel,
float,
int32_t,
int64_t,
double) {}
107 changes: 107 additions & 0 deletions paddle/phi/kernels/gpu/logspace_kernel.cu
@@ -0,0 +1,107 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上哈,麻烦也调整下,#include "paddle/phi/kernels/logspace_kernel.h"在最前面,用空行隔开

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里gpu_primitives.h看起来好像没有使用?是否可以移除?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已移除

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/logspace_kernel.h"

namespace phi {

template <typename T>
__global__ void LogspaceKernelInner(
T start, T stop, double step, T base, int64_t size, T* out) {
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;

for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) {
out[index] =
static_cast<T>(pow(static_cast<double>(base),
static_cast<double>(start + step * index)));
} else {
out[index] = static_cast<T>(
pow(static_cast<double>(base),
static_cast<double>(stop - step * (size - index - 1))));
}
}
}

template <typename T>
__global__ void LogspaceSpecialKernel(T start, T base, T* out) {
out[0] = static_cast<T>(
pow(static_cast<double>(base), static_cast<double>(start)));
}

template <typename T, typename Context>
void LogspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
const DenseTensor& base,
DataType dtype,
DenseTensor* out) {
auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
auto base_t = phi::funcs::TransDataType(ctx, base, dtype);

DenseTensor n_start;
DenseTensor n_stop;
DenseTensor n_num;
DenseTensor n_base;
phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start);
T start_data = n_start.data<T>()[0];
phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop);
T stop_data = n_stop.data<T>()[0];
phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num);
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base);
T base_data = n_base.data<T>()[0];

PADDLE_ENFORCE_GT(
num,
0,
phi::errors::InvalidArgument("The num of logspace op should be larger "
"than 0, but received num is %d",
num));

out->Resize(phi::make_ddim({num}));
T* out_data = ctx.template Alloc<T>(out);

double step = 0;
auto stream = ctx.stream();
int block = 512;
int grid = (num + block - 1) / block;
if (num != 1) {
step = (static_cast<double>(stop_data - start_data)) / (num - 1);
LogspaceKernelInner<T><<<grid, block, 0, stream>>>(
start_data, stop_data, step, base_data, num, out_data);
} else {
LogspaceSpecialKernel<T><<<grid, block, 0, stream>>>(
start_data, base_data, out_data);
}
}

} // namespace phi

PD_REGISTER_KERNEL(logspace,
GPU,
ALL_LAYOUT,
phi::LogspaceKernel,
float,
int32_t,
int64_t,
double) {}
27 changes: 27 additions & 0 deletions paddle/phi/kernels/logspace_kernel.h
@@ -0,0 +1,27 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的license格式好像还是有点问题,可以后续再完善下

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void LogspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
const DenseTensor& base,
DataType dtype,
DenseTensor* out);

} // namespace phi
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Expand Up @@ -89,6 +89,7 @@
from .tensor.creation import diagflat # noqa: F401
from .tensor.creation import eye # noqa: F401
from .tensor.creation import linspace # noqa: F401
from .tensor.creation import logspace # noqa: F401
from .tensor.creation import ones # noqa: F401
from .tensor.creation import ones_like # noqa: F401
from .tensor.creation import zeros # noqa: F401
Expand Down Expand Up @@ -588,6 +589,7 @@
'sqrt',
'randperm',
'linspace',
'logspace',
'reshape',
'reshape_',
'reverse',
Expand Down