Skip to content

Commit

Permalink
调整代码格式
Browse files Browse the repository at this point in the history
  • Loading branch information
BrilliantYuKaimin committed Apr 7, 2022
1 parent 7b9c8ed commit 3e2e62a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
11 changes: 6 additions & 5 deletions paddle/phi/kernels/cpu/logspace_kernel.cc
Expand Up @@ -13,10 +13,11 @@
// limitations under the License.

#include <cmath>
#include "paddle/phi/kernels/logspace_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/logspace_kernel.h"

namespace phi {

Expand Down Expand Up @@ -52,11 +53,11 @@ void LogspaceKernel(const Context& ctx,
int half_num = num / 2;
for (int i = 0; i < num; ++i) {
if (i < half_num) {
out_data[i] = static_cast<T>(std::pow(
base_data, start_data + step * i));
out_data[i] =
static_cast<T>(std::pow(base_data, start_data + step * i));
} else {
out_data[i] = static_cast<T>(std::pow(
base_data, stop_data - step * (num - i - 1)));
out_data[i] = static_cast<T>(
std::pow(base_data, stop_data - step * (num - i - 1)));
}
}
} else {
Expand Down
16 changes: 8 additions & 8 deletions paddle/phi/kernels/gpu/logspace_kernel.cu
Expand Up @@ -29,21 +29,21 @@ __global__ void LogspaceKernelInner(

for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) {
out[index] = static_cast<T>(pow(
static_cast<double>(base),
static_cast<double>(start + step * index)));
out[index] =
static_cast<T>(pow(static_cast<double>(base),
static_cast<double>(start + step * index)));
} else {
out[index] = static_cast<T>(pow(
static_cast<double>(base),
static_cast<double>(stop - step * (size - index - 1))));
out[index] = static_cast<T>(
pow(static_cast<double>(base),
static_cast<double>(stop - step * (size - index - 1))));
}
}
}

template <typename T>
__global__ void LogspaceSpecialKernel(T start, T base, T* out) {
out[0] = static_cast<T>(pow(
static_cast<double>(base), static_cast<double>(start)));
out[0] = static_cast<T>(
pow(static_cast<double>(base), static_cast<double>(start)));
}

template <typename T, typename Context>
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_logspace.py
Expand Up @@ -222,4 +222,4 @@ def test_base_dtype():


if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit 3e2e62a

Please sign in to comment.