Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【PaddlePaddle Hackathon 2】9、为 Paddle 新增 logspace API (#41261)
* 增加logspace的算子描述 * 增加logspace的形状推断 * 增加logspace核函数实现 * 在python中增加logspace接口 * 增加logspace单测 * 增加logspace * Update logspace_kernel.cu * Update logspace_op.cc * 调整代码格式 * Update doc of logspace * Update tensor.py * Update logspace_op.cc * Update logspace_kernel.cc * Update logspace_kernel.cu * Update test_logspace.py * 调整 logspace 的位置 * 调整代码格式
- Loading branch information
1 parent
885171e
commit a3c50c4
Showing
10 changed files
with
689 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <string> | ||
|
||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/multiary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class LogspaceOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
return framework::OpKernelType( | ||
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Start", | ||
"Exponent of first entry in the sequence. It is a tensor of " | ||
"shape [1], should be of type int32, int64, float32 or float64."); | ||
AddInput("Stop", | ||
"Exponent of last entry in the sequence. It is a tensor of " | ||
"shape [1], should be of type int32, int64, float32 or float64."); | ||
AddInput("Num", | ||
"Number of entry in the sequence. It is a tensor of shape [1], " | ||
"should be of type int32."); | ||
AddInput("Base", | ||
"Base of the logarithm function. It is a tensor of shape [1], " | ||
"should be of type int32, int64, float32 or float64."); | ||
AddAttr<int>("dtype", "The output data type."); | ||
AddOutput("Out", "A sequence of numbers."); | ||
AddComment(R"DOC( | ||
Return fixed number of logarithmical-evenly spaced values within a given | ||
interval. First entry is exponential of Start with base Base, and last | ||
entry is exponential of Stop with base Base. In the case when Num is 1, | ||
only exponential of Start with base Base is returned. If dtype is int32 | ||
or int64, the decimal part of values will be truncated. | ||
Like logspace function of numpy. | ||
)DOC"); | ||
} | ||
}; | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(logspace, LogspaceInferShapeFunctor, | ||
PD_INFER_META(phi::LogspaceInferMeta)); | ||
REGISTER_OPERATOR( | ||
logspace, ops::LogspaceOp, ops::LogspaceOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, | ||
LogspaceInferShapeFunctor); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/logspace_kernel.h" | ||
|
||
#include <cmath> | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/data_type_transform.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void LogspaceKernel(const Context& ctx, | ||
const DenseTensor& start, | ||
const DenseTensor& stop, | ||
const DenseTensor& number, | ||
const DenseTensor& base, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
int32_t num = number.data<int32_t>()[0]; | ||
auto start_t = phi::funcs::TransDataType(ctx, start, dtype); | ||
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); | ||
auto base_t = phi::funcs::TransDataType(ctx, base, dtype); | ||
|
||
T start_data = start_t.template data<T>()[0]; | ||
T stop_data = stop_t.template data<T>()[0]; | ||
T base_data = base_t.template data<T>()[0]; | ||
PADDLE_ENFORCE_GT( | ||
num, | ||
0, | ||
phi::errors::InvalidArgument("The num of logspace op should be larger " | ||
"than 0, but received num is %d", | ||
num)); | ||
|
||
out->Resize(phi::make_ddim({num})); | ||
T* out_data = ctx.template Alloc<T>(out); | ||
|
||
if (num > 1) { | ||
// step should be of double type for all types | ||
double step = (static_cast<double>(stop_data - start_data)) / (num - 1); | ||
int half_num = num / 2; | ||
for (int i = 0; i < num; ++i) { | ||
if (i < half_num) { | ||
out_data[i] = | ||
static_cast<T>(std::pow(base_data, start_data + step * i)); | ||
} else { | ||
out_data[i] = static_cast<T>( | ||
std::pow(base_data, stop_data - step * (num - i - 1))); | ||
} | ||
} | ||
} else { | ||
out_data[0] = static_cast<T>(std::pow(base_data, start_data)); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(logspace, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::LogspaceKernel, | ||
float, | ||
int32_t, | ||
int64_t, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/logspace_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/copy_kernel.h" | ||
#include "paddle/phi/kernels/funcs/data_type_transform.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T> | ||
__global__ void LogspaceKernelInner( | ||
T start, T stop, double step, T base, int64_t size, T* out) { | ||
int64_t index = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
for (; index < size; index += blockDim.x * gridDim.x) { | ||
if (index < size / 2) { | ||
out[index] = | ||
static_cast<T>(pow(static_cast<double>(base), | ||
static_cast<double>(start + step * index))); | ||
} else { | ||
out[index] = static_cast<T>( | ||
pow(static_cast<double>(base), | ||
static_cast<double>(stop - step * (size - index - 1)))); | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void LogspaceSpecialKernel(T start, T base, T* out) { | ||
out[0] = static_cast<T>( | ||
pow(static_cast<double>(base), static_cast<double>(start))); | ||
} | ||
|
||
template <typename T, typename Context> | ||
void LogspaceKernel(const Context& ctx, | ||
const DenseTensor& start, | ||
const DenseTensor& stop, | ||
const DenseTensor& number, | ||
const DenseTensor& base, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
auto start_t = phi::funcs::TransDataType(ctx, start, dtype); | ||
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); | ||
auto base_t = phi::funcs::TransDataType(ctx, base, dtype); | ||
|
||
DenseTensor n_start; | ||
DenseTensor n_stop; | ||
DenseTensor n_num; | ||
DenseTensor n_base; | ||
phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start); | ||
T start_data = n_start.data<T>()[0]; | ||
phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop); | ||
T stop_data = n_stop.data<T>()[0]; | ||
phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num); | ||
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]); | ||
phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base); | ||
T base_data = n_base.data<T>()[0]; | ||
|
||
PADDLE_ENFORCE_GT( | ||
num, | ||
0, | ||
phi::errors::InvalidArgument("The num of logspace op should be larger " | ||
"than 0, but received num is %d", | ||
num)); | ||
|
||
out->Resize(phi::make_ddim({num})); | ||
T* out_data = ctx.template Alloc<T>(out); | ||
|
||
double step = 0; | ||
auto stream = ctx.stream(); | ||
int block = 512; | ||
int grid = (num + block - 1) / block; | ||
if (num != 1) { | ||
step = (static_cast<double>(stop_data - start_data)) / (num - 1); | ||
LogspaceKernelInner<T><<<grid, block, 0, stream>>>( | ||
start_data, stop_data, step, base_data, num, out_data); | ||
} else { | ||
LogspaceSpecialKernel<T><<<grid, block, 0, stream>>>( | ||
start_data, base_data, out_data); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(logspace, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::LogspaceKernel, | ||
float, | ||
int32_t, | ||
int64_t, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void LogspaceKernel(const Context& ctx, | ||
const DenseTensor& start, | ||
const DenseTensor& stop, | ||
const DenseTensor& number, | ||
const DenseTensor& base, | ||
DataType dtype, | ||
DenseTensor* out); | ||
|
||
} // namespace phi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.