Skip to content

Commit

Permalink
add copyright, fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Aug 10, 2022
1 parent bb8517a commit 666bd68
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
16 changes: 8 additions & 8 deletions paddle/fluid/platform/device/gpu/gpu_primitives.h
Expand Up @@ -422,15 +422,15 @@ CUDA_ATOMIC_WRAPPER(Max, double) {
#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t max_to_low_half(uint32_t val, float x) {
float16 low_half;
// the float16 in lower 16bits
// The float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(max(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t max_to_high_half(uint32_t val, float x) {
float16 high_half;
// the float16 in higher 16bits
// The float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(max(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
Expand All @@ -447,7 +447,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// the float16 value stay at lower 16 bits of the address.
// The float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f));
Expand All @@ -456,7 +456,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
ret.x = old & 0xFFFFu;
return ret;
} else {
// the float16 value stay at higher 16 bits of the address.
// The float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f));
Expand Down Expand Up @@ -555,15 +555,15 @@ CUDA_ATOMIC_WRAPPER(Min, double) {
#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t min_to_low_half(uint32_t val, float x) {
float16 low_half;
// the float16 in lower 16bits
// The float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<float16>(min(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t min_to_high_half(uint32_t val, float x) {
float16 high_half;
// the float16 in higher 16bits
// The float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(min(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
Expand All @@ -580,7 +580,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// the float16 value stay at lower 16 bits of the address.
// The float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f));
Expand All @@ -589,7 +589,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
ret.x = old & 0xFFFFu;
return ret;
} else {
// the float16 value stay at higher 16 bits of the address.
// The float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f));
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/impl/graph_messaage_passing_impl.h
@@ -1,4 +1,5 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright The DGL team.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 666bd68

Please sign in to comment.