Skip to content

Commit

Permalink
Optimize of backward of log_softmax when axis is -1 and dim_size <= 1…
Browse files Browse the repository at this point in the history
…024 (#32180)
  • Loading branch information
AshburnLee committed Apr 14, 2021
1 parent 7da4455 commit 5dc0a6e
Showing 1 changed file with 126 additions and 6 deletions.
132 changes: 126 additions & 6 deletions paddle/fluid/operators/log_softmax_op.cu
Expand Up @@ -65,19 +65,14 @@ __global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;

// set effective_warp_id as 1 when warps do effective work,
// when warps do ineffective work, effective_warp_id remains unchanged.
int effective_warp_id = batch_size - batch_id;
if (effective_warp_id > 1) effective_warp_id = 1;

int thread_in_warp_idx = threadIdx.x;

// 1.read data from global memory to registers
AccT elements[warp_iter];
// set effective_element_count as the num of elements when warps do effective
// work
// set effective_element_count as 0, when warps do ineffective work
int effective_element_count = (effective_warp_id <= 0) ? 0 : element_count;
int effective_element_count = (batch_id < batch_size) ? element_count : 0;
for (int it = 0; it < warp_iter; ++it) {
int element_index = thread_in_warp_idx + it * kernel_warp_size;
if (element_index < effective_element_count) {
Expand Down Expand Up @@ -181,6 +176,131 @@ class LogSoftmaxKernel<platform::CUDADeviceContext, T>
}
};

// Backward below
#define LAUNCH_WARP_BACKWARD_COMPUTE(near_greater_power_of_two) \
case near_greater_power_of_two: \
ComputeLogSoftmaxBackwardInWarp< \
T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
output, grad_output, grad_input, outer_size, dim_size); \
break;

template <typename T, typename AccT, int NearGreaterPowerOfTwo>
__global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
const T *grad_output,
T *grad_input, int batch_size,
int element_count) {
constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
constexpr int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;

int thread_in_warp_idx = threadIdx.x % kernel_warp_size;

// 1.read data from global memory to registers
AccT output_register[warp_iter];
AccT grad_output_register[warp_iter];
int effective_element_count = (batch_id < batch_size) ? element_count : 0;
for (int iter = 0; iter < warp_iter; ++iter) {
int element_index = thread_in_warp_idx + iter * kernel_warp_size;
if (element_index < effective_element_count) {
output_register[iter] =
static_cast<AccT>(output[batch_id * element_count + element_index]);
grad_output_register[iter] = static_cast<AccT>(
grad_output[batch_id * element_count + element_index]);
} else {
output_register[iter] = AccT(0);
grad_output_register[iter] = AccT(0);
}
}

// 2. For each warp, accumulate all thread registers
AccT sum = grad_output_register[0];
#pragma unroll
for (int iter = 1; iter < warp_iter; ++iter) {
sum += grad_output_register[iter];
}
sum = WarpReduceSum<AccT, kernel_warp_size>(sum);

// 3. write result in grad_input
#pragma unroll
for (int iter = 0; iter < warp_iter; ++iter) {
int element_index = thread_in_warp_idx + iter * kernel_warp_size;
if (element_index < element_count) {
grad_input[batch_id * element_count + element_index] = static_cast<T>(
(grad_output_register[iter] - std::exp(output_register[iter]) * sum));
}
}
}

template <typename T, typename AccT>
void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output,
const T *output, int dim_size,
int outer_size, gpuStream_t stream) {
int threads_per_block = 128;
int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
int warps_per_block = (threads_per_block / kernel_warp_size);
int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
dim3 threads(kernel_warp_size, warps_per_block, 1);

switch (near_greater_power_of_two) {
LAUNCH_WARP_BACKWARD_COMPUTE(1); // dim_size: 1
LAUNCH_WARP_BACKWARD_COMPUTE(2); // dim_size: 2
LAUNCH_WARP_BACKWARD_COMPUTE(4); // dim_size: 3~4
LAUNCH_WARP_BACKWARD_COMPUTE(8); // dim_size: 5~8
LAUNCH_WARP_BACKWARD_COMPUTE(16); // dim_size: 9~16
LAUNCH_WARP_BACKWARD_COMPUTE(32); // dim_size: 17~32
LAUNCH_WARP_BACKWARD_COMPUTE(64); // dim_size: 33~64
LAUNCH_WARP_BACKWARD_COMPUTE(128); // dim_size: 65~128
LAUNCH_WARP_BACKWARD_COMPUTE(256); // dim_size: 129~256
LAUNCH_WARP_BACKWARD_COMPUTE(512); // dim_size: 257~512
LAUNCH_WARP_BACKWARD_COMPUTE(1024); // dim_size: 513~1024

default:
break;
}
}

template <typename T>
class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;

public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *out = context.Input<framework::Tensor>("Out");
const auto *g_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *g_x = context.Output<framework::Tensor>(framework::GradVarName("X"));

const auto *out_data = out->data<T>();
const auto *g_out_data = g_out->data<T>();
auto *g_x_data = g_x->mutable_data<T>(context.GetPlace());

const int rank = out->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);

int dim_size = out->dims()[axis];
int inner_size = 1;
for (int i = axis + 1; i < out->dims().size(); ++i) {
inner_size *= out->dims()[i];
}
int outer_size = SizeToAxis(axis, out->dims());
gpuStream_t stream = context.cuda_device_context().stream();

if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
g_x_data, g_out_data, out_data, dim_size, outer_size, stream);
} else {
LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
context.template device_context<platform::CUDADeviceContext>(), out,
g_out, g_x, axis);
}
}
};

} // operators
} // paddle

Expand Down

0 comments on commit 5dc0a6e

Please sign in to comment.