From 441da36cb4080172c37acb547a73cc580344cbd3 Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sun, 29 May 2022 18:00:48 +0800 Subject: [PATCH 01/13] refactor code structure --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 182 ++++++++++---------- 1 file changed, 89 insertions(+), 93 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 361e62e566035..e572ec70dbebe 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -451,10 +451,80 @@ void BatchNormKernel(const Context &ctx, paddle::framework::TensorCopy(x, ctx.GetPlace(), y); } else { double this_factor = 1. - momentum; - - bool called = false; +#ifdef PADDLE_WITH_HIP + const int num = transformed_x.numel(); + const int block = 256; + const int max_threads = ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + const int grid = std::min(C, max_blocks); + if (compute_format == DataLayout::kNCHW) { + BNForwardTraining< + T, + block, + DataLayout::kNCHW><<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>()); + } else { + BNForwardTraining< + T, + block, + DataLayout::kNHWC><<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>()); + } +// TODO(wangran16): wait for MIOpen to improve the performance of BN +// PADDLE_ENFORCE_GPU_SUCCESS( +// platform::dynload::miopenBatchNormalizationForwardTraining( +// handle, mode_, const_cast(static_cast( +// CudnnDataType::kOne())), +// const_cast( +// static_cast(CudnnDataType::kZero())), +// data_desc_, +// static_cast(transformed_x.template data()), +// data_desc_, +// static_cast( +// transformed_y.template mutable_data(ctx.GetPlace())), +// bn_param_desc_, +// const_cast(static_cast( +// scale->template data>())), +// const_cast(static_cast( +// bias->template data>())), +// this_factor, +// static_cast( +// mean_out->template mutable_data>( +// ctx.GetPlace())), +// static_cast(variance_out->template mutable_data< +// BatchNormParamType>(ctx.GetPlace())), +// epsilon, +// static_cast( +// saved_mean->template mutable_data>( +// ctx.GetPlace())), +// static_cast(saved_variance->template mutable_data< +// BatchNormParamType>(ctx.GetPlace())))); +#else #if CUDNN_VERSION_MIN(7, 4, 1) - called = true; size_t workspace_size = 0; size_t reserve_space_size = 0; void *reserve_space_ptr = nullptr; @@ -530,102 +600,28 @@ void BatchNormKernel(const Context &ctx, workspace_size, reserve_space_ptr, reserve_space_size)); -#endif // CUDNN_VERSION_MIN(7, 4, 1) - if (!called) { -#ifdef PADDLE_WITH_HIP - const int num = transformed_x.numel(); - const int block = 256; - const int max_threads = ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - const int grid = std::min(C, max_blocks); - if (compute_format == DataLayout::kNCHW) { - BNForwardTraining< - T, - block, - DataLayout::kNCHW><<>>( +#else + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnBatchNormalizationForwardTraining( + handle, + 
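// cuDNN < 7.4.1 fallback: the non-Ex training API fills saved_mean /
// saved_variance with the batch mean and inverse variance and updates the
// running statistics with this_factor, matching the ForwardTrainingEx
// branch above but without the workspace / reserve-space handling.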
mode_, + CudnnDataType::kOne(), + CudnnDataType::kZero(), + data_desc_, transformed_x.template data(), + data_desc_, + ctx.template Alloc(&transformed_y), + bn_param_desc_, scale.template data>(), bias.template data>(), - C, - N, - H * W * D, - epsilon, this_factor, - transformed_y.template data(), - mean_out->template data>(), - variance_out->template data>(), - saved_mean->template data>(), - saved_variance->template data>()); - } else { - BNForwardTraining< - T, - block, - DataLayout::kNHWC><<>>( - transformed_x.template data(), - scale.template data>(), - bias.template data>(), - C, - N, - H * W * D, + ctx.template Alloc>(mean_out), + ctx.template Alloc>(variance_out), epsilon, - this_factor, - transformed_y.template data(), - mean_out->template data>(), - variance_out->template data>(), - saved_mean->template data>(), - saved_variance->template data>()); - } -// TODO(wangran16): wait for MIOpen to improve the performance of BN -// PADDLE_ENFORCE_GPU_SUCCESS( -// platform::dynload::miopenBatchNormalizationForwardTraining( -// handle, mode_, const_cast(static_cast( -// CudnnDataType::kOne())), -// const_cast( -// static_cast(CudnnDataType::kZero())), -// data_desc_, -// static_cast(transformed_x.template data()), -// data_desc_, -// static_cast( -// transformed_y.template mutable_data(ctx.GetPlace())), -// bn_param_desc_, -// const_cast(static_cast( -// scale->template data>())), -// const_cast(static_cast( -// bias->template data>())), -// this_factor, -// static_cast( -// mean_out->template mutable_data>( -// ctx.GetPlace())), -// static_cast(variance_out->template mutable_data< -// BatchNormParamType>(ctx.GetPlace())), -// epsilon, -// static_cast( -// saved_mean->template mutable_data>( -// ctx.GetPlace())), -// static_cast(saved_variance->template mutable_data< -// BatchNormParamType>(ctx.GetPlace())))); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cudnnBatchNormalizationForwardTraining( - handle, - mode_, - CudnnDataType::kOne(), - CudnnDataType::kZero(), - data_desc_, - transformed_x.template data(), - data_desc_, - ctx.template Alloc(&transformed_y), - bn_param_desc_, - scale.template data>(), - bias.template data>(), - this_factor, - ctx.template Alloc>(mean_out), - ctx.template Alloc>(variance_out), - epsilon, - ctx.template Alloc>(saved_mean), - ctx.template Alloc>(saved_variance))); + ctx.template Alloc>(saved_mean), + ctx.template Alloc>(saved_variance))); +#endif // CUDNN_VERSION_MIN(7, 4, 1) #endif - } } } From c48e076349ba257bcd87f874f792aaf449852fd8 Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sun, 29 May 2022 22:55:48 +0800 Subject: [PATCH 02/13] add native kernel usage --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 222 ++++++++++++-------- 1 file changed, 134 insertions(+), 88 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index e572ec70dbebe..08eea1f8717cd 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -524,103 +524,149 @@ void BatchNormKernel(const Context &ctx, // static_cast(saved_variance->template mutable_data< // BatchNormParamType>(ctx.GetPlace())))); #else -#if CUDNN_VERSION_MIN(7, 4, 1) - size_t workspace_size = 0; - size_t reserve_space_size = 0; - void *reserve_space_ptr = nullptr; - void *workspace_ptr = nullptr; - DenseTensor workspace_tensor; - DenseTensor reserve_space_tensor; - // Create reserve space and workspace for batch norm. 
- // Create tensor for each batchnorm op, it will be used in the - // backward. Thus this tensor shouldn't be temp. - // auto *reserve_space = ctx.Output("ReserveSpace"); - if (reserve_space == nullptr) { - reserve_space = &reserve_space_tensor; - } - PADDLE_ENFORCE_NOT_NULL( - reserve_space, - phi::errors::NotFound( - "The argument ReserveSpace of batch_norm op is not found.")); - // --------------- cudnn batchnorm workspace --------------- - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload:: - cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnIps=*/CUDNN_BATCHNORM_OPS_BN, - /*xDesc=*/data_desc_, - /*zDesc=*/nullptr, - /*yDesc=*/data_desc_, - /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, - /*activationDesc=*/nullptr, - /*sizeInBytes=*/&workspace_size)); - - // -------------- cudnn batchnorm reserve space -------------- - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload:: - cudnnGetBatchNormalizationTrainingExReserveSpaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, - /*activationDesc=*/nullptr, - /*xDesc=*/data_desc_, - /*sizeInBytes=*/&reserve_space_size)); - - reserve_space->Resize({static_cast(reserve_space_size)}); - reserve_space_ptr = - static_cast(ctx.template Alloc(reserve_space)); - workspace_tensor.Resize({static_cast(workspace_size)}); - workspace_ptr = - static_cast(ctx.template Alloc(&workspace_tensor)); - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx( - handle, - mode_, - CUDNN_BATCHNORM_OPS_BN, - CudnnDataType::kOne(), - CudnnDataType::kZero(), - data_desc_, + const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); + if(use_native_kernel) { + const int num = transformed_x.numel(); + const int block = 256; + const int max_threads = ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + const int grid = std::min(C, max_blocks); + if (compute_format == DataLayout::kNCHW) { + BNForwardTraining< + T, + block, + DataLayout::kNCHW><<>>( transformed_x.template data(), - nullptr, - nullptr, - data_desc_, - transformed_y.template data(), - bn_param_desc_, scale.template data>(), bias.template data>(), - this_factor, - ctx.template Alloc>(mean_out), - ctx.template Alloc>(variance_out), + C, + N, + H * W * D, epsilon, - ctx.template Alloc>(saved_mean), - ctx.template Alloc>(saved_variance), - nullptr, - workspace_ptr, - workspace_size, - reserve_space_ptr, - reserve_space_size)); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cudnnBatchNormalizationForwardTraining( - handle, - mode_, - CudnnDataType::kOne(), - CudnnDataType::kZero(), - data_desc_, + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>()); + } else { + BNForwardTraining< + T, + block, + DataLayout::kNHWC><<>>( transformed_x.template data(), - data_desc_, - ctx.template Alloc(&transformed_y), - bn_param_desc_, scale.template data>(), bias.template data>(), - this_factor, - ctx.template Alloc>(mean_out), - ctx.template Alloc>(variance_out), + C, + N, + H * W * D, epsilon, - ctx.template Alloc>(saved_mean), - ctx.template Alloc>(saved_variance))); + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>()); + } + } else { +#if CUDNN_VERSION_MIN(7, 4, 1) + size_t 
workspace_size = 0; + size_t reserve_space_size = 0; + void *reserve_space_ptr = nullptr; + void *workspace_ptr = nullptr; + DenseTensor workspace_tensor; + DenseTensor reserve_space_tensor; + // Create reserve space and workspace for batch norm. + // Create tensor for each batchnorm op, it will be used in the + // backward. Thus this tensor shouldn't be temp. + // auto *reserve_space = ctx.Output("ReserveSpace"); + if (reserve_space == nullptr) { + reserve_space = &reserve_space_tensor; + } + PADDLE_ENFORCE_NOT_NULL( + reserve_space, + phi::errors::NotFound( + "The argument ReserveSpace of batch_norm op is not found.")); + // --------------- cudnn batchnorm workspace --------------- + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload:: + cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnIps=*/CUDNN_BATCHNORM_OPS_BN, + /*xDesc=*/data_desc_, + /*zDesc=*/nullptr, + /*yDesc=*/data_desc_, + /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, + /*activationDesc=*/nullptr, + /*sizeInBytes=*/&workspace_size)); + + // -------------- cudnn batchnorm reserve space -------------- + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload:: + cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, + /*activationDesc=*/nullptr, + /*xDesc=*/data_desc_, + /*sizeInBytes=*/&reserve_space_size)); + + reserve_space->Resize({static_cast(reserve_space_size)}); + reserve_space_ptr = + static_cast(ctx.template Alloc(reserve_space)); + workspace_tensor.Resize({static_cast(workspace_size)}); + workspace_ptr = + static_cast(ctx.template Alloc(&workspace_tensor)); + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx( + handle, + mode_, + CUDNN_BATCHNORM_OPS_BN, + CudnnDataType::kOne(), + CudnnDataType::kZero(), + data_desc_, + transformed_x.template data(), + nullptr, + nullptr, + data_desc_, + transformed_y.template data(), + bn_param_desc_, + scale.template data>(), + bias.template data>(), + this_factor, + ctx.template Alloc>(mean_out), + ctx.template Alloc>(variance_out), + epsilon, + ctx.template Alloc>(saved_mean), + ctx.template Alloc>(saved_variance), + nullptr, + workspace_ptr, + workspace_size, + reserve_space_ptr, + reserve_space_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnBatchNormalizationForwardTraining( + handle, + mode_, + CudnnDataType::kOne(), + CudnnDataType::kZero(), + data_desc_, + transformed_x.template data(), + data_desc_, + ctx.template Alloc(&transformed_y), + bn_param_desc_, + scale.template data>(), + bias.template data>(), + this_factor, + ctx.template Alloc>(mean_out), + ctx.template Alloc>(variance_out), + epsilon, + ctx.template Alloc>(saved_mean), + ctx.template Alloc>(saved_variance))); #endif // CUDNN_VERSION_MIN(7, 4, 1) + } #endif } } From 0a68ba3641939219cc22c45a5ca772e419107ecf Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sat, 4 Jun 2022 00:21:54 +0800 Subject: [PATCH 03/13] add wellford impl --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 268 +++++++++++++++++++- 1 file changed, 264 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 08eea1f8717cd..15a84c4ae918e 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -140,6 +140,265 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining( } } + 
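// Welford / Chan-style merging used by the two kernels added below (sketch of
// the math, for reference):
//   per-element update:   n += 1;  delta = x - mean;  mean += delta / n;
//                         M2 += delta * (x - mean);
//   merging two partials (n_a, mean_a, M2_a) and (n_b, mean_b, M2_b):
//     n    = n_a + n_b
//     mean = (n_a * mean_a + n_b * mean_b) / n
//     M2   = M2_a + M2_b + (mean_a - mean_b)^2 * n_a * n_b / n
//   final (biased) variance = M2 / n,  inv_std = 1 / sqrt(variance + epsilon).
// The warp- and block-level reductions below implement exactly this merge with
// __shfl_xor_sync butterfly exchanges; the factor variable is 1 / n guarded
// against n == 0 via max(1, n).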
+template +static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingWellford( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *mean, + BatchNormParamType *variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance) { + int outer_size = C; + int inner_size = N * HxW; + __shared__ BatchNormParamType mean_val; + __shared__ BatchNormParamType variance_val; + __shared__ BatchNormParamType inv_var_val; + + constexpr int THREADS_PER_WARP = 32; + constexpr int THREADS_BITS_PER_WARP = 5; + + constexpr int WARP_PER_BLOCK = BlockDim / THREADS_PER_WARP; + const int WARP_BITS_PER_BLOCK = (31 - __clz(WARP_PER_BLOCK)); + + __shared__ int warp_shared_count[WARP_PER_BLOCK]; + __shared__ BatchNormParamType warp_shared_mean[WARP_PER_BLOCK]; + __shared__ BatchNormParamType warp_shared_var_n[WARP_PER_BLOCK]; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType local_mean = static_cast>(0); + BatchNormParamType local_var_n = static_cast>(0); + int local_count = 0; + + // thread-local iterative computation + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == phi::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_i = static_cast>(x[index]); + BatchNormParamType delta = (x_i - local_mean); + local_count++; + local_mean += delta / local_count; + local_var_n += delta * (x_i - local_mean); + } + + // warp sum + for(int b_i = 0; b_i < THREADS_BITS_PER_WARP; b_i++) { + BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); + int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); + BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); + local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); + local_mean = (local_count * local_mean + o_count * o_mean) * factor; + local_count += o_count; + } + + if (threadIdx.x % THREADS_PER_WARP == 0) { + warp_shared_count[threadIdx.x / THREADS_PER_WARP] = local_count; + warp_shared_mean[threadIdx.x / THREADS_PER_WARP] = local_mean; + warp_shared_var_n[threadIdx.x / THREADS_PER_WARP] = local_var_n; + } + __syncthreads(); + + // block sum + if (threadIdx.x < WARP_PER_BLOCK) { + local_count = warp_shared_count[threadIdx.x]; + local_mean = warp_shared_count[threadIdx.x]; + local_var_n = warp_shared_count[threadIdx.x]; + } + + for(int b_i = 0; b_i < WARP_BITS_PER_BLOCK; b_i++) { + BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); + int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); + BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); + local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); + local_mean = (local_count * local_mean + o_count * o_mean) * factor; + local_count += o_count; + } + + if (threadIdx.x == 0) { + mean_val = local_mean; + variance_val = local_var_n / local_count; + inv_var_val = 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = mean_val; + save_inv_variance[i] = inv_var_val; + } + mean[i] = (1 
- exponentialAverageFactor) * mean_val + + exponentialAverageFactor * mean[i]; + variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * variance[i]; + } + __syncthreads(); + + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == phi::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_sub_mean = + static_cast>(x[index]) - mean_val; + y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i]; + } + } +} + +template +static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingWellfordParallel( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *mean, + BatchNormParamType *variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance) { + int outer_size = C; + int inner_size = N * HxW; + __shared__ BatchNormParamType mean_val; + __shared__ BatchNormParamType variance_val; + __shared__ BatchNormParamType inv_var_val; + + constexpr int PARALLEL_LOADS = 4; + + constexpr int THREADS_PER_WARP = 32; + constexpr int THREADS_BITS_PER_WARP = 5; + + constexpr int WARP_PER_BLOCK = BlockDim / THREADS_PER_WARP; + const int WARP_BITS_PER_BLOCK = (31 - __clz(WARP_PER_BLOCK)); + + __shared__ int warp_shared_count[WARP_PER_BLOCK]; + __shared__ BatchNormParamType warp_shared_mean[WARP_PER_BLOCK]; + __shared__ BatchNormParamType warp_shared_var_n[WARP_PER_BLOCK]; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType tmp_local_mean[PARALLEL_LOADS]; + BatchNormParamType tmp_local_var_n[PARALLEL_LOADS]; + int tmp_local_count[PARALLEL_LOADS]; + + #pragma unroll + for(int k = 0; k < PARALLEL_LOADS; k++) { + tmp_local_mean[k] = static_cast>(0); + tmp_local_var_n[k] = static_cast>(0); + tmp_local_count[k] = 0; + } + + // thread-local iterative computation + for (int j = threadIdx.x; j < inner_size; j += PARALLEL_LOADS * blockDim.x) { + BatchNormParamType tmp_local_x[PARALLEL_LOADS]; + BatchNormParamType tmp_local_count_inv[PARALLEL_LOADS]; + BatchNormParamType valid[PARALLEL_LOADS]; + auto offset = j; + #pragma unroll + for(int k = 0; k < PARALLEL_LOADS; k++) { + if(offset < inner_size) { + const int index = layout == phi::DataLayout::kNCHW + ? 
(offset / HxW * C + i) * HxW + offset % HxW + : offset * outer_size + i; + tmp_local_x[k] = static_cast>(x[index]); + tmp_local_count[k]++; + tmp_local_count_inv[k] = static_cast>(1) / tmp_local_count[k]; + valid[k] = static_cast>(1); + } else { + tmp_local_x[k] = static_cast>(0); + tmp_local_count_inv[k] = static_cast>(0); + valid[k] = static_cast>(0); + } + offset += blockDim.x; + } + + #pragma unroll + for(int k = 0; k < PARALLEL_LOADS; k++) { + BatchNormParamType delta = (tmp_local_x[k] - tmp_local_mean[k]); + tmp_local_mean[k] += delta * tmp_local_count_inv[k]; + tmp_local_var_n[k] += delta * (tmp_local_x[k] - tmp_local_mean[k]) * valid[k]; + } + } + + #pragma unroll + for(int k = 1; k < PARALLEL_LOADS; k++) { + BatchNormParamType factor = 1.0 / static_cast(max(1, tmp_local_count[0]+tmp_local_count[k])); + BatchNormParamType delta = (tmp_local_mean[0] - tmp_local_mean[k]); + tmp_local_mean[0] = (tmp_local_count[0] * tmp_local_mean[0] + tmp_local_count[k] * tmp_local_mean[k]) * factor; + tmp_local_var_n[0] += (tmp_local_var_n[k] + delta * delta * tmp_local_count[0] * tmp_local_count[k] * factor); + tmp_local_count[0] += tmp_local_count[k]; + } + + BatchNormParamType local_mean = tmp_local_mean[0]; + BatchNormParamType local_var_n = tmp_local_var_n[0]; + int local_count = tmp_local_count[0]; + + // warp sum + for(int b_i = 0; b_i < THREADS_BITS_PER_WARP; b_i++) { + BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); + int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); + BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); + local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); + local_mean = (local_count * local_mean + o_count * o_mean) * factor; + local_count += o_count; + } + + if (threadIdx.x % THREADS_PER_WARP == 0) { + warp_shared_count[threadIdx.x / THREADS_PER_WARP] = local_count; + warp_shared_mean[threadIdx.x / THREADS_PER_WARP] = local_mean; + warp_shared_var_n[threadIdx.x / THREADS_PER_WARP] = local_var_n; + } + __syncthreads(); + + // block sum + if (threadIdx.x < WARP_PER_BLOCK) { + local_count = warp_shared_count[threadIdx.x]; + local_mean = warp_shared_count[threadIdx.x]; + local_var_n = warp_shared_count[threadIdx.x]; + } + + for(int b_i = 0; b_i < WARP_BITS_PER_BLOCK; b_i++) { + BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); + int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); + BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); + local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); + local_mean = (local_count * local_mean + o_count * o_mean) * factor; + local_count += o_count; + } + + if (threadIdx.x == 0) { + mean_val = local_mean; + variance_val = local_var_n / local_count; + inv_var_val = 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = mean_val; + save_inv_variance[i] = inv_var_val; + } + mean[i] = (1 - exponentialAverageFactor) * mean_val + + exponentialAverageFactor * mean[i]; + variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * variance[i]; + } + __syncthreads(); + + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = 
layout == phi::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_sub_mean = + static_cast>(x[index]) - mean_val; + y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i]; + } + } +} + template void BatchNormKernel(const Context &ctx, const DenseTensor &x, @@ -524,15 +783,16 @@ void BatchNormKernel(const Context &ctx, // static_cast(saved_variance->template mutable_data< // BatchNormParamType>(ctx.GetPlace())))); #else - const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); + //const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); + const bool use_native_kernel = true; if(use_native_kernel) { const int num = transformed_x.numel(); - const int block = 256; + const int block = 1024; const int max_threads = ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); const int grid = std::min(C, max_blocks); if (compute_format == DataLayout::kNCHW) { - BNForwardTraining< + BNForwardTrainingWellfordParallel< T, block, DataLayout::kNCHW><<>>( @@ -550,7 +810,7 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } else { - BNForwardTraining< + BNForwardTrainingWellfordParallel< T, block, DataLayout::kNHWC><<>>( From b3248c9ef649ac2073853ae0b39f8b4fa3175d1b Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sat, 4 Jun 2022 15:32:25 +0800 Subject: [PATCH 04/13] add shmem impl --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 83 +++++++++++++++++++-- 1 file changed, 77 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 15a84c4ae918e..394b6399977dd 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -399,6 +399,77 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingWellfordParallel } } + +template +static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingSMem( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *mean, + BatchNormParamType *variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance) { + extern __shared__ __align__(sizeof(double)) char smem_buf[]; + BatchNormParamType* x_buf = reinterpret_cast*>(smem_buf); + + int outer_size = C; + int inner_size = N * HxW; + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage mean_storage; + __shared__ typename BlockReduce::TempStorage variance_storeage; + __shared__ BatchNormParamType mean_val; + __shared__ BatchNormParamType variance_val; + __shared__ BatchNormParamType inv_var_val; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType x_sum = static_cast>(0); + BatchNormParamType x_square_sum = static_cast>(0); + + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == phi::DataLayout::kNCHW + ? 
(j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_i = static_cast>(x[index]); + x_buf[j] = x_i; + x_sum += x_i; + x_square_sum += x_i * x_i; + } + x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum()); + x_square_sum = + BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum()); + if (threadIdx.x == 0) { + mean_val = x_sum / inner_size; + variance_val = x_square_sum / inner_size - mean_val * mean_val; + inv_var_val = 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = mean_val; + save_inv_variance[i] = inv_var_val; + } + mean[i] = (1 - exponentialAverageFactor) * mean_val + + exponentialAverageFactor * mean[i]; + variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * variance[i]; + } + __syncthreads(); + + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == phi::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_sub_mean = + static_cast>(x_buf[j]) - mean_val; + y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i]; + } + } +} + template void BatchNormKernel(const Context &ctx, const DenseTensor &x, @@ -786,16 +857,16 @@ void BatchNormKernel(const Context &ctx, //const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); const bool use_native_kernel = true; if(use_native_kernel) { - const int num = transformed_x.numel(); - const int block = 1024; + const int block = 512; const int max_threads = ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); const int grid = std::min(C, max_blocks); + const size_t smem_size = N * H * W * D * sizeof(BatchNormParamType); if (compute_format == DataLayout::kNCHW) { - BNForwardTrainingWellfordParallel< + BNForwardTrainingSMem< T, block, - DataLayout::kNCHW><<>>( + DataLayout::kNCHW><<>>( transformed_x.template data(), scale.template data>(), bias.template data>(), @@ -810,10 +881,10 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } else { - BNForwardTrainingWellfordParallel< + BNForwardTrainingSMem< T, block, - DataLayout::kNHWC><<>>( + DataLayout::kNHWC><<>>( transformed_x.template data(), scale.template data>(), bias.template data>(), From 78349a2d57bc34fc0897bacd9f42cc2fa9a918ee Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sat, 4 Jun 2022 17:07:29 +0800 Subject: [PATCH 05/13] add dispatch logic --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 92 ++++++++++++++++++--- 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 394b6399977dd..ab797aa186f65 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -470,6 +470,81 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingSMem( } } +template +inline bool TryDispatchBNForwardTrainingSMem( + const Context &ctx, + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *mean, + BatchNormParamType *variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance) { + constexpr int block_size = 512; + const size_t smem = N * HxW * sizeof(BatchNormParamType); + int max_active_blocks_conf; + { + 
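// Occupancy probe: the shared-memory variant caches the whole per-channel
// column of x (N * HxW elements) in dynamic shared memory, so first ask the
// runtime how many such blocks can be resident per SM with `smem` bytes.
// A result of zero or less means the request does not fit, and the caller
// falls back to the plain BNForwardTraining reduction kernel.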
cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf, + BNForwardTrainingSMem, + block_size, smem); + } + if (max_active_blocks_conf <= 0) { + return false; + } + const int max_threads = ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block_size, 1); + const int grid = std::min(C, max_blocks); + BNForwardTrainingSMem<<>>( + x, scale, bias, C, N, HxW, epsilon, exponentialAverageFactor, + y, mean, variance, save_mean, save_inv_variance); + return true; +} + +template +inline void DispatchBNForwardTraining( + const Context &ctx, + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *mean, + BatchNormParamType *variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance) { + if ((N * HxW) <= 1024) { + // TODO: impl register-cache version + return; + } else { + bool dispatch_smem_impl_success = false; + { + dispatch_smem_impl_success = TryDispatchBNForwardTrainingSMem( + ctx, x, scale, bias, C, N, HxW, epsilon, exponentialAverageFactor, + y, mean, variance, save_mean, save_inv_variance); + } + if (!dispatch_smem_impl_success) { + const int block = 512; + const int max_threads = ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + const int grid = std::min(C, max_blocks); + return BNForwardTraining<<>>( + x, scale, bias, C, N, HxW, epsilon, exponentialAverageFactor, + y, mean, variance, save_mean, save_inv_variance); + } + } +} + template void BatchNormKernel(const Context &ctx, const DenseTensor &x, @@ -857,16 +932,9 @@ void BatchNormKernel(const Context &ctx, //const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); const bool use_native_kernel = true; if(use_native_kernel) { - const int block = 512; - const int max_threads = ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - const int grid = std::min(C, max_blocks); - const size_t smem_size = N * H * W * D * sizeof(BatchNormParamType); if (compute_format == DataLayout::kNCHW) { - BNForwardTrainingSMem< - T, - block, - DataLayout::kNCHW><<>>( + DispatchBNForwardTraining( + ctx, transformed_x.template data(), scale.template data>(), bias.template data>(), @@ -881,10 +949,8 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } else { - BNForwardTrainingSMem< - T, - block, - DataLayout::kNHWC><<>>( + DispatchBNForwardTraining( + ctx, transformed_x.template data(), scale.template data>(), bias.template data>(), From 98c66f0df7b2ab0c33c89a0bd2a18fce2b13440a Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sun, 5 Jun 2022 00:21:12 +0800 Subject: [PATCH 06/13] add channel_last impl --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 280 +++++++++++++++++--- 1 file changed, 250 insertions(+), 30 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index ab797aa186f65..14778a89a4657 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -25,6 +25,7 @@ namespace cub = hipcub; #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/batch_norm_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/fluid/operators/norm_utils.cu.h" #include 
"paddle/fluid/operators/norm_utils.h" @@ -399,6 +400,158 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingWellfordParallel } } +template +static __global__ void BNForwardTraining2D( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *mean, + BatchNormParamType *variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance, + BatchNormParamType *block_data_ptr, + int *flag_ptr) { + int outer_size = C; + int inner_size = N * HxW; + + extern __shared__ __align__(sizeof(double)) char smem_buf[]; + + BatchNormParamType* mean_val = reinterpret_cast*>(smem_buf); + BatchNormParamType* variance_val = reinterpret_cast*>(&smem_buf[blockDim.x]); + BatchNormParamType* inv_var_val = reinterpret_cast*>(&smem_buf[2*blockDim.x]); + + __shared__ BatchNormParamType smem_sum[BlockDim]; + __shared__ BatchNormParamType smem_square_sum[BlockDim]; + + int outer_loop_stride = gridDim.x * blockDim.x; + int inner_loop_stride = gridDim.y * blockDim.y; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; i += outer_loop_stride) { + BatchNormParamType x_sum = static_cast>(0); + BatchNormParamType x_square_sum = static_cast>(0); + + for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; j += inner_loop_stride) { + const int index = layout == phi::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_i = static_cast>(x[index]); + x_sum += x_i; + x_square_sum += x_i * x_i; + } + + // vertical block sum + int tid = threadIdx.x + threadIdx.y * blockDim.x; + #pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + smem_sum[tid] = x_sum; + smem_square_sum[tid] = x_square_sum; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + int pair_tid = tid + offset * blockDim.x; + x_sum += smem_sum[pair_tid]; + x_square_sum += smem_square_sum[pair_tid]; + } + } + + if (gridDim.y > 1) { + volatile BatchNormParamType* staging_sum = block_data_ptr; + volatile BatchNormParamType* staging_square_sum = &block_data_ptr[C*gridDim.y]; + // write block data to global memory + if (threadIdx.y == 0) { + staging_sum[i + blockIdx.y * C] = x_sum; + staging_square_sum[i + blockIdx.y * C] = x_square_sum; + } + + // make sure write is visible to all blocks + __threadfence(); + __syncthreads(); + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&flag_ptr[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + if (is_last_block_done) { + x_sum = static_cast>(0); + x_square_sum = static_cast>(0); + // thread sum + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + x_sum += staging_sum[i+y*C]; + x_square_sum += staging_square_sum[i+y*C]; + } + + // vertical block sum + int tid = threadIdx.x + threadIdx.y * blockDim.x; + #pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + smem_sum[tid] = x_sum; + smem_square_sum[tid] = x_square_sum; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + int pair_tid = tid + offset * blockDim.x; + x_sum += smem_sum[pair_tid]; + x_square_sum += smem_square_sum[pair_tid]; + } + } + + // final compute + if(threadIdx.y == 0) { + 
mean_val[threadIdx.x] = x_sum / inner_size; + variance_val[threadIdx.x] = x_square_sum / inner_size - mean_val[threadIdx.x] * mean_val[threadIdx.x]; + inv_var_val[threadIdx.x] = 1 / sqrt(variance_val[threadIdx.x] + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = mean_val[threadIdx.x]; + save_inv_variance[i] = inv_var_val[threadIdx.x]; + } + mean[i] = (1 - exponentialAverageFactor) * mean_val[threadIdx.x] + + exponentialAverageFactor * mean[i]; + variance[i] = (1 - exponentialAverageFactor) * variance_val[threadIdx.x] + + exponentialAverageFactor * variance[i]; + } + } + } else { + if(blockIdx.y == 0 && threadIdx.y == 0) { + mean_val[threadIdx.x] = x_sum / inner_size; + variance_val[threadIdx.x] = x_square_sum / inner_size - mean_val[threadIdx.x] * mean_val[threadIdx.x]; + inv_var_val[threadIdx.x] = 1 / sqrt(variance_val[threadIdx.x] + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = mean_val[threadIdx.x]; + save_inv_variance[i] = inv_var_val[threadIdx.x]; + } + mean[i] = (1 - exponentialAverageFactor) * mean_val[threadIdx.x] + + exponentialAverageFactor * mean[i]; + variance[i] = (1 - exponentialAverageFactor) * variance_val[threadIdx.x] + + exponentialAverageFactor * variance[i]; + } + } + __syncthreads(); + + for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; j += blockDim.x) { + const int index = layout == phi::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType x_sub_mean = + static_cast>(x[index]) - mean_val[threadIdx.x]; + y[index] = scale[i] * x_sub_mean * inv_var_val[threadIdx.x] + bias[i]; + } + } +} template static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingSMem( @@ -932,38 +1085,105 @@ void BatchNormKernel(const Context &ctx, //const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); const bool use_native_kernel = true; if(use_native_kernel) { + dim3 block; + dim3 grid; + + const int block_size = 512; + // init block&grid config + int block_x = std::min(phi::funcs::details::GetLastPow2(C), 32); + int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), block_size / block_x); + if (block_x * block_y != block_size) { + block_x = std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y); + } + int grid_x = (C + block_x - 1) / block_x; + int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), 128); + + block.x = block_x; + block.y = block_y; + grid.x = grid_x; + grid.y = grid_y; + + // init intermediate storage + DenseTensor block_data_tensor; + DenseTensor flag_tensor; + BatchNormParamType* block_data_ptr = nullptr; + int* flag_ptr = nullptr; + if(grid.y > 1) { + block_data_tensor.Resize({static_cast(2 * C * grid.y * sizeof(BatchNormParamType))}); + flag_tensor.Resize({static_cast(grid.x * sizeof(int))}); + + block_data_ptr = static_cast*>(ctx.template Alloc>(&block_data_tensor)); + flag_ptr = static_cast(ctx.template Alloc(&flag_tensor)); + } + + size_t smem_size = 3 * sizeof(BatchNormParamType) * block.x; if (compute_format == DataLayout::kNCHW) { - DispatchBNForwardTraining( - ctx, - transformed_x.template data(), - scale.template data>(), - bias.template data>(), - C, - N, - H * W * D, - epsilon, - this_factor, - transformed_y.template data(), - mean_out->template data>(), - variance_out->template data>(), - saved_mean->template data>(), - saved_variance->template data>()); + BNForwardTraining2D + <<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H 
* W * D, + epsilon, + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>(), + block_data_ptr, + flag_ptr); + // DispatchBNForwardTraining( + // ctx, + // transformed_x.template data(), + // scale.template data>(), + // bias.template data>(), + // C, + // N, + // H * W * D, + // epsilon, + // this_factor, + // transformed_y.template data(), + // mean_out->template data>(), + // variance_out->template data>(), + // saved_mean->template data>(), + // saved_variance->template data>()); } else { - DispatchBNForwardTraining( - ctx, - transformed_x.template data(), - scale.template data>(), - bias.template data>(), - C, - N, - H * W * D, - epsilon, - this_factor, - transformed_y.template data(), - mean_out->template data>(), - variance_out->template data>(), - saved_mean->template data>(), - saved_variance->template data>()); + BNForwardTraining2D + <<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>(), + block_data_ptr, + flag_ptr); + + // DispatchBNForwardTraining( + // ctx, + // transformed_x.template data(), + // scale.template data>(), + // bias.template data>(), + // C, + // N, + // H * W * D, + // epsilon, + // this_factor, + // transformed_y.template data(), + // mean_out->template data>(), + // variance_out->template data>(), + // saved_mean->template data>(), + // saved_variance->template data>()); } } else { #if CUDNN_VERSION_MIN(7, 4, 1) From f4320ccfa2ea9de1881820808bd0ba35c26bdbfc Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Mon, 6 Jun 2022 10:25:22 +0800 Subject: [PATCH 07/13] revert --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 634 +------------------- 1 file changed, 9 insertions(+), 625 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 14778a89a4657..1e3b346271b23 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -141,563 +141,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining( } } - -template -static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingWellford( - const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const int C, - const int N, - const int HxW, - const double epsilon, - double exponentialAverageFactor, - T *y, - BatchNormParamType *mean, - BatchNormParamType *variance, - BatchNormParamType *save_mean, - BatchNormParamType *save_inv_variance) { - int outer_size = C; - int inner_size = N * HxW; - __shared__ BatchNormParamType mean_val; - __shared__ BatchNormParamType variance_val; - __shared__ BatchNormParamType inv_var_val; - - constexpr int THREADS_PER_WARP = 32; - constexpr int THREADS_BITS_PER_WARP = 5; - - constexpr int WARP_PER_BLOCK = BlockDim / THREADS_PER_WARP; - const int WARP_BITS_PER_BLOCK = (31 - __clz(WARP_PER_BLOCK)); - - __shared__ int warp_shared_count[WARP_PER_BLOCK]; - __shared__ BatchNormParamType warp_shared_mean[WARP_PER_BLOCK]; - __shared__ BatchNormParamType warp_shared_var_n[WARP_PER_BLOCK]; - - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - BatchNormParamType local_mean = static_cast>(0); - BatchNormParamType local_var_n = static_cast>(0); - int 
local_count = 0; - - // thread-local iterative computation - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int index = layout == phi::DataLayout::kNCHW - ? (j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_i = static_cast>(x[index]); - BatchNormParamType delta = (x_i - local_mean); - local_count++; - local_mean += delta / local_count; - local_var_n += delta * (x_i - local_mean); - } - - // warp sum - for(int b_i = 0; b_i < THREADS_BITS_PER_WARP; b_i++) { - BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); - int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); - BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); - local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); - local_mean = (local_count * local_mean + o_count * o_mean) * factor; - local_count += o_count; - } - - if (threadIdx.x % THREADS_PER_WARP == 0) { - warp_shared_count[threadIdx.x / THREADS_PER_WARP] = local_count; - warp_shared_mean[threadIdx.x / THREADS_PER_WARP] = local_mean; - warp_shared_var_n[threadIdx.x / THREADS_PER_WARP] = local_var_n; - } - __syncthreads(); - - // block sum - if (threadIdx.x < WARP_PER_BLOCK) { - local_count = warp_shared_count[threadIdx.x]; - local_mean = warp_shared_count[threadIdx.x]; - local_var_n = warp_shared_count[threadIdx.x]; - } - - for(int b_i = 0; b_i < WARP_BITS_PER_BLOCK; b_i++) { - BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); - int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); - BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); - local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); - local_mean = (local_count * local_mean + o_count * o_mean) * factor; - local_count += o_count; - } - - if (threadIdx.x == 0) { - mean_val = local_mean; - variance_val = local_var_n / local_count; - inv_var_val = 1 / sqrt(variance_val + epsilon); - - if (save_mean && save_inv_variance) { - save_mean[i] = mean_val; - save_inv_variance[i] = inv_var_val; - } - mean[i] = (1 - exponentialAverageFactor) * mean_val + - exponentialAverageFactor * mean[i]; - variance[i] = (1 - exponentialAverageFactor) * variance_val + - exponentialAverageFactor * variance[i]; - } - __syncthreads(); - - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int index = layout == phi::DataLayout::kNCHW - ? 
(j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_sub_mean = - static_cast>(x[index]) - mean_val; - y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i]; - } - } -} - -template -static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingWellfordParallel( - const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const int C, - const int N, - const int HxW, - const double epsilon, - double exponentialAverageFactor, - T *y, - BatchNormParamType *mean, - BatchNormParamType *variance, - BatchNormParamType *save_mean, - BatchNormParamType *save_inv_variance) { - int outer_size = C; - int inner_size = N * HxW; - __shared__ BatchNormParamType mean_val; - __shared__ BatchNormParamType variance_val; - __shared__ BatchNormParamType inv_var_val; - - constexpr int PARALLEL_LOADS = 4; - - constexpr int THREADS_PER_WARP = 32; - constexpr int THREADS_BITS_PER_WARP = 5; - - constexpr int WARP_PER_BLOCK = BlockDim / THREADS_PER_WARP; - const int WARP_BITS_PER_BLOCK = (31 - __clz(WARP_PER_BLOCK)); - - __shared__ int warp_shared_count[WARP_PER_BLOCK]; - __shared__ BatchNormParamType warp_shared_mean[WARP_PER_BLOCK]; - __shared__ BatchNormParamType warp_shared_var_n[WARP_PER_BLOCK]; - - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - BatchNormParamType tmp_local_mean[PARALLEL_LOADS]; - BatchNormParamType tmp_local_var_n[PARALLEL_LOADS]; - int tmp_local_count[PARALLEL_LOADS]; - - #pragma unroll - for(int k = 0; k < PARALLEL_LOADS; k++) { - tmp_local_mean[k] = static_cast>(0); - tmp_local_var_n[k] = static_cast>(0); - tmp_local_count[k] = 0; - } - - // thread-local iterative computation - for (int j = threadIdx.x; j < inner_size; j += PARALLEL_LOADS * blockDim.x) { - BatchNormParamType tmp_local_x[PARALLEL_LOADS]; - BatchNormParamType tmp_local_count_inv[PARALLEL_LOADS]; - BatchNormParamType valid[PARALLEL_LOADS]; - auto offset = j; - #pragma unroll - for(int k = 0; k < PARALLEL_LOADS; k++) { - if(offset < inner_size) { - const int index = layout == phi::DataLayout::kNCHW - ? 
(offset / HxW * C + i) * HxW + offset % HxW - : offset * outer_size + i; - tmp_local_x[k] = static_cast>(x[index]); - tmp_local_count[k]++; - tmp_local_count_inv[k] = static_cast>(1) / tmp_local_count[k]; - valid[k] = static_cast>(1); - } else { - tmp_local_x[k] = static_cast>(0); - tmp_local_count_inv[k] = static_cast>(0); - valid[k] = static_cast>(0); - } - offset += blockDim.x; - } - - #pragma unroll - for(int k = 0; k < PARALLEL_LOADS; k++) { - BatchNormParamType delta = (tmp_local_x[k] - tmp_local_mean[k]); - tmp_local_mean[k] += delta * tmp_local_count_inv[k]; - tmp_local_var_n[k] += delta * (tmp_local_x[k] - tmp_local_mean[k]) * valid[k]; - } - } - - #pragma unroll - for(int k = 1; k < PARALLEL_LOADS; k++) { - BatchNormParamType factor = 1.0 / static_cast(max(1, tmp_local_count[0]+tmp_local_count[k])); - BatchNormParamType delta = (tmp_local_mean[0] - tmp_local_mean[k]); - tmp_local_mean[0] = (tmp_local_count[0] * tmp_local_mean[0] + tmp_local_count[k] * tmp_local_mean[k]) * factor; - tmp_local_var_n[0] += (tmp_local_var_n[k] + delta * delta * tmp_local_count[0] * tmp_local_count[k] * factor); - tmp_local_count[0] += tmp_local_count[k]; - } - - BatchNormParamType local_mean = tmp_local_mean[0]; - BatchNormParamType local_var_n = tmp_local_var_n[0]; - int local_count = tmp_local_count[0]; - - // warp sum - for(int b_i = 0; b_i < THREADS_BITS_PER_WARP; b_i++) { - BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); - int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); - BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); - local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); - local_mean = (local_count * local_mean + o_count * o_mean) * factor; - local_count += o_count; - } - - if (threadIdx.x % THREADS_PER_WARP == 0) { - warp_shared_count[threadIdx.x / THREADS_PER_WARP] = local_count; - warp_shared_mean[threadIdx.x / THREADS_PER_WARP] = local_mean; - warp_shared_var_n[threadIdx.x / THREADS_PER_WARP] = local_var_n; - } - __syncthreads(); - - // block sum - if (threadIdx.x < WARP_PER_BLOCK) { - local_count = warp_shared_count[threadIdx.x]; - local_mean = warp_shared_count[threadIdx.x]; - local_var_n = warp_shared_count[threadIdx.x]; - } - - for(int b_i = 0; b_i < WARP_BITS_PER_BLOCK; b_i++) { - BatchNormParamType o_mean = __shfl_xor_sync(0xffffffff, local_mean, 1 << b_i, THREADS_PER_WARP); - int o_count = __shfl_xor_sync(0xffffffff, local_count, 1 << b_i, THREADS_PER_WARP); - BatchNormParamType factor = 1.0 / static_cast(max(1, local_count+o_count)); - local_var_n += (__shfl_xor_sync(0xffffffff, local_var_n, 1 << b_i, THREADS_PER_WARP) + (local_mean - o_mean) * (local_mean - o_mean) * local_count * o_count * factor); - local_mean = (local_count * local_mean + o_count * o_mean) * factor; - local_count += o_count; - } - - if (threadIdx.x == 0) { - mean_val = local_mean; - variance_val = local_var_n / local_count; - inv_var_val = 1 / sqrt(variance_val + epsilon); - - if (save_mean && save_inv_variance) { - save_mean[i] = mean_val; - save_inv_variance[i] = inv_var_val; - } - mean[i] = (1 - exponentialAverageFactor) * mean_val + - exponentialAverageFactor * mean[i]; - variance[i] = (1 - exponentialAverageFactor) * variance_val + - exponentialAverageFactor * variance[i]; - } - __syncthreads(); - - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int index = 
layout == phi::DataLayout::kNCHW - ? (j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_sub_mean = - static_cast>(x[index]) - mean_val; - y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i]; - } - } -} - -template -static __global__ void BNForwardTraining2D( - const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const int C, - const int N, - const int HxW, - const double epsilon, - double exponentialAverageFactor, - T *y, - BatchNormParamType *mean, - BatchNormParamType *variance, - BatchNormParamType *save_mean, - BatchNormParamType *save_inv_variance, - BatchNormParamType *block_data_ptr, - int *flag_ptr) { - int outer_size = C; - int inner_size = N * HxW; - - extern __shared__ __align__(sizeof(double)) char smem_buf[]; - - BatchNormParamType* mean_val = reinterpret_cast*>(smem_buf); - BatchNormParamType* variance_val = reinterpret_cast*>(&smem_buf[blockDim.x]); - BatchNormParamType* inv_var_val = reinterpret_cast*>(&smem_buf[2*blockDim.x]); - - __shared__ BatchNormParamType smem_sum[BlockDim]; - __shared__ BatchNormParamType smem_square_sum[BlockDim]; - - int outer_loop_stride = gridDim.x * blockDim.x; - int inner_loop_stride = gridDim.y * blockDim.y; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; i += outer_loop_stride) { - BatchNormParamType x_sum = static_cast>(0); - BatchNormParamType x_square_sum = static_cast>(0); - - for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; j += inner_loop_stride) { - const int index = layout == phi::DataLayout::kNCHW - ? (j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_i = static_cast>(x[index]); - x_sum += x_i; - x_square_sum += x_i * x_i; - } - - // vertical block sum - int tid = threadIdx.x + threadIdx.y * blockDim.x; - #pragma unroll - for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { - if (threadIdx.y < offset*2) { - smem_sum[tid] = x_sum; - smem_square_sum[tid] = x_square_sum; - } - __syncthreads(); - if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { - int pair_tid = tid + offset * blockDim.x; - x_sum += smem_sum[pair_tid]; - x_square_sum += smem_square_sum[pair_tid]; - } - } - - if (gridDim.y > 1) { - volatile BatchNormParamType* staging_sum = block_data_ptr; - volatile BatchNormParamType* staging_square_sum = &block_data_ptr[C*gridDim.y]; - // write block data to global memory - if (threadIdx.y == 0) { - staging_sum[i + blockIdx.y * C] = x_sum; - staging_square_sum[i + blockIdx.y * C] = x_square_sum; - } - - // make sure write is visible to all blocks - __threadfence(); - __syncthreads(); - - __shared__ bool is_last_block_done; - // mark block done - if (threadIdx.x == 0 && threadIdx.y == 0) { - int old = atomicAdd(&flag_ptr[blockIdx.x], 1); - is_last_block_done = (old == (gridDim.y-1)); - } - - __syncthreads(); - - if (is_last_block_done) { - x_sum = static_cast>(0); - x_square_sum = static_cast>(0); - // thread sum - for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { - x_sum += staging_sum[i+y*C]; - x_square_sum += staging_square_sum[i+y*C]; - } - - // vertical block sum - int tid = threadIdx.x + threadIdx.y * blockDim.x; - #pragma unroll - for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { - if (threadIdx.y < offset*2) { - smem_sum[tid] = x_sum; - smem_square_sum[tid] = x_square_sum; - } - __syncthreads(); - if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { - int pair_tid = tid + offset * blockDim.x; - x_sum += smem_sum[pair_tid]; - x_square_sum += 
smem_square_sum[pair_tid]; - } - } - - // final compute - if(threadIdx.y == 0) { - mean_val[threadIdx.x] = x_sum / inner_size; - variance_val[threadIdx.x] = x_square_sum / inner_size - mean_val[threadIdx.x] * mean_val[threadIdx.x]; - inv_var_val[threadIdx.x] = 1 / sqrt(variance_val[threadIdx.x] + epsilon); - - if (save_mean && save_inv_variance) { - save_mean[i] = mean_val[threadIdx.x]; - save_inv_variance[i] = inv_var_val[threadIdx.x]; - } - mean[i] = (1 - exponentialAverageFactor) * mean_val[threadIdx.x] + - exponentialAverageFactor * mean[i]; - variance[i] = (1 - exponentialAverageFactor) * variance_val[threadIdx.x] + - exponentialAverageFactor * variance[i]; - } - } - } else { - if(blockIdx.y == 0 && threadIdx.y == 0) { - mean_val[threadIdx.x] = x_sum / inner_size; - variance_val[threadIdx.x] = x_square_sum / inner_size - mean_val[threadIdx.x] * mean_val[threadIdx.x]; - inv_var_val[threadIdx.x] = 1 / sqrt(variance_val[threadIdx.x] + epsilon); - - if (save_mean && save_inv_variance) { - save_mean[i] = mean_val[threadIdx.x]; - save_inv_variance[i] = inv_var_val[threadIdx.x]; - } - mean[i] = (1 - exponentialAverageFactor) * mean_val[threadIdx.x] + - exponentialAverageFactor * mean[i]; - variance[i] = (1 - exponentialAverageFactor) * variance_val[threadIdx.x] + - exponentialAverageFactor * variance[i]; - } - } - __syncthreads(); - - for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; j += blockDim.x) { - const int index = layout == phi::DataLayout::kNCHW - ? (j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_sub_mean = - static_cast>(x[index]) - mean_val[threadIdx.x]; - y[index] = scale[i] * x_sub_mean * inv_var_val[threadIdx.x] + bias[i]; - } - } -} - -template -static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTrainingSMem( - const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const int C, - const int N, - const int HxW, - const double epsilon, - double exponentialAverageFactor, - T *y, - BatchNormParamType *mean, - BatchNormParamType *variance, - BatchNormParamType *save_mean, - BatchNormParamType *save_inv_variance) { - extern __shared__ __align__(sizeof(double)) char smem_buf[]; - BatchNormParamType* x_buf = reinterpret_cast*>(smem_buf); - - int outer_size = C; - int inner_size = N * HxW; - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage mean_storage; - __shared__ typename BlockReduce::TempStorage variance_storeage; - __shared__ BatchNormParamType mean_val; - __shared__ BatchNormParamType variance_val; - __shared__ BatchNormParamType inv_var_val; - - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - BatchNormParamType x_sum = static_cast>(0); - BatchNormParamType x_square_sum = static_cast>(0); - - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int index = layout == phi::DataLayout::kNCHW - ? 
(j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_i = static_cast>(x[index]); - x_buf[j] = x_i; - x_sum += x_i; - x_square_sum += x_i * x_i; - } - x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum()); - x_square_sum = - BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum()); - if (threadIdx.x == 0) { - mean_val = x_sum / inner_size; - variance_val = x_square_sum / inner_size - mean_val * mean_val; - inv_var_val = 1 / sqrt(variance_val + epsilon); - - if (save_mean && save_inv_variance) { - save_mean[i] = mean_val; - save_inv_variance[i] = inv_var_val; - } - mean[i] = (1 - exponentialAverageFactor) * mean_val + - exponentialAverageFactor * mean[i]; - variance[i] = (1 - exponentialAverageFactor) * variance_val + - exponentialAverageFactor * variance[i]; - } - __syncthreads(); - - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int index = layout == phi::DataLayout::kNCHW - ? (j / HxW * C + i) * HxW + j % HxW - : j * outer_size + i; - BatchNormParamType x_sub_mean = - static_cast>(x_buf[j]) - mean_val; - y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i]; - } - } -} - -template -inline bool TryDispatchBNForwardTrainingSMem( - const Context &ctx, - const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const int C, - const int N, - const int HxW, - const double epsilon, - double exponentialAverageFactor, - T *y, - BatchNormParamType *mean, - BatchNormParamType *variance, - BatchNormParamType *save_mean, - BatchNormParamType *save_inv_variance) { - constexpr int block_size = 512; - const size_t smem = N * HxW * sizeof(BatchNormParamType); - int max_active_blocks_conf; - { - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks_conf, - BNForwardTrainingSMem, - block_size, smem); - } - if (max_active_blocks_conf <= 0) { - return false; - } - const int max_threads = ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block_size, 1); - const int grid = std::min(C, max_blocks); - BNForwardTrainingSMem<<>>( - x, scale, bias, C, N, HxW, epsilon, exponentialAverageFactor, - y, mean, variance, save_mean, save_inv_variance); - return true; -} - -template -inline void DispatchBNForwardTraining( - const Context &ctx, - const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const int C, - const int N, - const int HxW, - const double epsilon, - double exponentialAverageFactor, - T *y, - BatchNormParamType *mean, - BatchNormParamType *variance, - BatchNormParamType *save_mean, - BatchNormParamType *save_inv_variance) { - if ((N * HxW) <= 1024) { - // TODO: impl register-cache version - return; - } else { - bool dispatch_smem_impl_success = false; - { - dispatch_smem_impl_success = TryDispatchBNForwardTrainingSMem( - ctx, x, scale, bias, C, N, HxW, epsilon, exponentialAverageFactor, - y, mean, variance, save_mean, save_inv_variance); - } - if (!dispatch_smem_impl_success) { - const int block = 512; - const int max_threads = ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - const int grid = std::min(C, max_blocks); - return BNForwardTraining<<>>( - x, scale, bias, C, N, HxW, epsilon, exponentialAverageFactor, - y, mean, variance, save_mean, save_inv_variance); - } - } -} - template void BatchNormKernel(const Context &ctx, const DenseTensor &x, @@ -1085,41 +528,15 @@ void BatchNormKernel(const Context &ctx, //const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); const bool 
use_native_kernel = true; if(use_native_kernel) { - dim3 block; - dim3 grid; - - const int block_size = 512; - // init block&grid config - int block_x = std::min(phi::funcs::details::GetLastPow2(C), 32); - int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), block_size / block_x); - if (block_x * block_y != block_size) { - block_x = std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y); - } - int grid_x = (C + block_x - 1) / block_x; - int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), 128); - - block.x = block_x; - block.y = block_y; - grid.x = grid_x; - grid.y = grid_y; - - // init intermediate storage - DenseTensor block_data_tensor; - DenseTensor flag_tensor; - BatchNormParamType* block_data_ptr = nullptr; - int* flag_ptr = nullptr; - if(grid.y > 1) { - block_data_tensor.Resize({static_cast(2 * C * grid.y * sizeof(BatchNormParamType))}); - flag_tensor.Resize({static_cast(grid.x * sizeof(int))}); - - block_data_ptr = static_cast*>(ctx.template Alloc>(&block_data_tensor)); - flag_ptr = static_cast(ctx.template Alloc(&flag_tensor)); - } - - size_t smem_size = 3 * sizeof(BatchNormParamType) * block.x; + const int block = 256; + const int max_threads = ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + const int grid = std::min(C, max_blocks); if (compute_format == DataLayout::kNCHW) { - BNForwardTraining2D - <<>>( + BNForwardTraining< + T, + block, + DataLayout::kNCHW><<>>( transformed_x.template data(), scale.template data>(), bias.template data>(), @@ -1132,24 +549,7 @@ void BatchNormKernel(const Context &ctx, mean_out->template data>(), variance_out->template data>(), saved_mean->template data>(), - saved_variance->template data>(), - block_data_ptr, - flag_ptr); - // DispatchBNForwardTraining( - // ctx, - // transformed_x.template data(), - // scale.template data>(), - // bias.template data>(), - // C, - // N, - // H * W * D, - // epsilon, - // this_factor, - // transformed_y.template data(), - // mean_out->template data>(), - // variance_out->template data>(), - // saved_mean->template data>(), - // saved_variance->template data>()); + saved_variance->template data>()); } else { BNForwardTraining2D <<>>( @@ -1168,22 +568,6 @@ void BatchNormKernel(const Context &ctx, saved_variance->template data>(), block_data_ptr, flag_ptr); - - // DispatchBNForwardTraining( - // ctx, - // transformed_x.template data(), - // scale.template data>(), - // bias.template data>(), - // C, - // N, - // H * W * D, - // epsilon, - // this_factor, - // transformed_y.template data(), - // mean_out->template data>(), - // variance_out->template data>(), - // saved_mean->template data>(), - // saved_variance->template data>()); } } else { #if CUDNN_VERSION_MIN(7, 4, 1) From 958535d9d371a7dffec16c8607991521821b0a8b Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Mon, 6 Jun 2022 14:13:04 +0800 Subject: [PATCH 08/13] minor --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 1e3b346271b23..9da0027598b5f 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -25,7 +25,6 @@ namespace cub = hipcub; #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/batch_norm_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" -#include 
"paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/fluid/operators/norm_utils.cu.h" #include "paddle/fluid/operators/norm_utils.h" @@ -525,10 +524,9 @@ void BatchNormKernel(const Context &ctx, // static_cast(saved_variance->template mutable_data< // BatchNormParamType>(ctx.GetPlace())))); #else - //const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); - const bool use_native_kernel = true; + const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); if(use_native_kernel) { - const int block = 256; + const int block = 512; const int max_threads = ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); const int grid = std::min(C, max_blocks); @@ -551,8 +549,10 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } else { - BNForwardTraining2D - <<>>( + BNForwardTraining< + T, + block, + DataLayout::kNHWC><<>>( transformed_x.template data(), scale.template data>(), bias.template data>(), @@ -565,9 +565,7 @@ void BatchNormKernel(const Context &ctx, mean_out->template data>(), variance_out->template data>(), saved_mean->template data>(), - saved_variance->template data>(), - block_data_ptr, - flag_ptr); + saved_variance->template data>()); } } else { #if CUDNN_VERSION_MIN(7, 4, 1) From 5d3e4ec70d3728bf18dc783fdd266e479a0bc5a9 Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Tue, 7 Jun 2022 15:03:13 +0800 Subject: [PATCH 09/13] format --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 68 ++++++++++----------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 2085e22159572..5832a332670c1 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -516,47 +516,43 @@ void BatchNormKernel(const Context &ctx, // BatchNormParamType>(ctx.GetPlace())))); #else const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); - if(use_native_kernel) { + if (use_native_kernel) { const int block = 512; const int max_threads = ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); const int grid = std::min(C, max_blocks); if (compute_format == DataLayout::kNCHW) { - BNForwardTraining< - T, - block, - DataLayout::kNCHW><<>>( - transformed_x.template data(), - scale.template data>(), - bias.template data>(), - C, - N, - H * W * D, - epsilon, - this_factor, - transformed_y.template data(), - mean_out->template data>(), - variance_out->template data>(), - saved_mean->template data>(), - saved_variance->template data>()); + BNForwardTraining + <<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + this_factor, + transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>()); } else { - BNForwardTraining< - T, - block, - DataLayout::kNHWC><<>>( - transformed_x.template data(), - scale.template data>(), - bias.template data>(), - C, - N, - H * W * D, - epsilon, - this_factor, - transformed_y.template data(), - mean_out->template data>(), - variance_out->template data>(), - saved_mean->template data>(), - saved_variance->template data>()); + BNForwardTraining + <<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + epsilon, + this_factor, + 
transformed_y.template data(), + mean_out->template data>(), + variance_out->template data>(), + saved_mean->template data>(), + saved_variance->template data>()); } } else { #if CUDNN_VERSION_MIN(7, 4, 1) @@ -655,7 +651,7 @@ void BatchNormKernel(const Context &ctx, epsilon, ctx.template Alloc>(saved_mean), ctx.template Alloc>(saved_variance))); -#endif // CUDNN_VERSION_MIN(7, 4, 1) +#endif // CUDNN_VERSION_MIN(7, 4, 1) } #endif } From 55b1f2a68a66e67d8467cea988390894e3618a03 Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Tue, 7 Jun 2022 15:40:41 +0800 Subject: [PATCH 10/13] remove magic number --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 5832a332670c1..702722591553f 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -515,7 +515,10 @@ void BatchNormKernel(const Context &ctx, // static_cast(saved_variance->template mutable_data< // BatchNormParamType>(ctx.GetPlace())))); #else - const bool use_native_kernel = (x_dims.size() == 2 && N >= 131070); + // CUDNN PER_ACTIVATION mode only support small batch size + const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; + const bool use_native_kernel = + (x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD); if (use_native_kernel) { const int block = 512; const int max_threads = ctx.GetMaxPhysicalThreadCount(); From 26a43fc56609d1d65b233edb26c71cb2353a0c5b Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Fri, 10 Jun 2022 02:33:28 +0800 Subject: [PATCH 11/13] fix backward --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 702722591553f..73af404910e0b 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -557,6 +557,22 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } +#if CUDNN_VERSION_MIN(7, 4, 1) + // -------------- allocate reserve space for backward-------------- + if (reserve_space != nullptr) { + size_t reserve_space_size = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload:: + cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, + /*activationDesc=*/nullptr, + /*xDesc=*/data_desc_, + /*sizeInBytes=*/&reserve_space_size)); + reserve_space->Resize({static_cast(reserve_space_size)}); + } +#endif } else { #if CUDNN_VERSION_MIN(7, 4, 1) size_t workspace_size = 0; From 7adb69699cca40f303a9685044d118e538f8cdf6 Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sat, 11 Jun 2022 13:49:46 +0800 Subject: [PATCH 12/13] fix backward --- .../phi/kernels/gpu/batch_norm_grad_kernel.cu | 194 +++++++++++------- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 16 -- 2 files changed, 115 insertions(+), 95 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 6de239182c15b..b23b119342d68 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -542,70 +542,60 @@ void BatchNormGradRawKernel(const Context &ctx, // This branch calls CUDNN APIs if (d_x && d_scale && d_bias) { - bool called = 
false; -#if CUDNN_VERSION_MIN(7, 4, 1) - called = true; - size_t workspace_size = 0; - void *workspace_ptr = nullptr; - DenseTensor workspace_tensor; - auto reserve_space_size = reserve_space->memory_size(); - // --------------- cudnn batchnorm workspace --------------- - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload:: - cudnnGetBatchNormalizationBackwardExWorkspaceSize( - /*handle=*/ctx.cudnn_handle(), - /*mode=*/mode_, - /*bnIps=*/CUDNN_BATCHNORM_OPS_BN, - /*xDesc=*/data_desc_, - /*yDesc=*/data_desc_, - /*dyDesc=*/data_desc_, - /*dzDesc=*/nullptr, - /*dxDesc=*/data_desc_, - /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, - /*activationDesc=*/nullptr, - /*sizeInBytes=*/&workspace_size)); - - workspace_tensor.Resize({static_cast(workspace_size)}); - workspace_ptr = - static_cast(ctx.template Alloc(&workspace_tensor)); - - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload::cudnnBatchNormalizationBackwardEx( - /*handle=*/ctx.cudnn_handle(), - /*mode=*/mode_, - /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, - /*alphaDataDiff=*/CudnnDataType::kOne(), - /*betaDataDiff=*/CudnnDataType::kZero(), - /*alphaParamDiff=*/CudnnDataType::kOne(), - /*betaParamDiff=*/CudnnDataType::kZero(), - /*xDesc=*/data_desc_, - /*xData=*/transformed_x.template data(), - /*yDesc=*/nullptr, - /*yData=*/nullptr, - /*dyDesc=*/data_desc_, - /*dyData=*/transformed_d_y.template data(), - /*dzDesc=*/nullptr, - /*dzData=*/nullptr, - /*dxDesc=*/data_desc_, - /*dxData=*/ctx.template Alloc(&transformed_d_x), - /*dBnScaleBiasDesc=*/bn_param_desc_, - /*bnScaleData=*/scale.template data>(), - /*bnBiasData=*/nullptr, - /*dBnScaleData=*/ - ctx.template Alloc>(d_scale), - /*dBnBiasData=*/ctx.template Alloc>(d_bias), - /*epsilon=*/epsilon, - /*savedMean=*/saved_mean_data, - /*savedInvVariance=*/saved_var_data, - /*activationDesc=*/nullptr, - /*workspace=*/workspace_ptr, - /*workSpaceSizeInBytes=*/workspace_size, - /*reserveSpace=*/ - const_cast(reserve_space->template data()), - /*reserveSpaceSizeInBytes=*/reserve_space_size)); -#endif // CUDNN_VERSION_MIN(7, 4, 1) - if (!called) { #ifdef PADDLE_WITH_HIP + if (compute_format == DataLayout::kNCHW) { + BNBackward + <<>>( + transformed_d_y.template data(), + transformed_x.template data(), + scale.template data>(), + saved_mean_data, + saved_var_data, + C, + N, + H * W * D, + epsilon, + transformed_d_x.template data(), + ctx.template Alloc>(d_scale), + ctx.template Alloc>(d_bias)); + } else { + BNBackward + <<>>( + transformed_d_y.template data(), + transformed_x.template data(), + scale.template data>(), + saved_mean_data, + saved_var_data, + C, + N, + H * W * D, + epsilon, + transformed_d_x.template data(), + ctx.template Alloc>(d_scale), + ctx.template Alloc>(d_bias)); + } + +// TODO(wangran16): wait for MIOpen to improve the performance of BN +// PADDLE_ENFORCE_GPU_SUCCESS( +// platform::dynload::miopenBatchNormalizationBackward( +// dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), +// CudnnDataType::kZero(), CudnnDataType::kOne(), +// CudnnDataType::kZero(), data_desc_, +// transformed_x.template data(), data_desc_, +// transformed_d_y.template data(), data_desc_, +// transformed_d_x.template mutable_data(ctx.GetPlace()), +// bn_param_desc_, scale->template data>(), +// d_scale->template mutable_data>( +// ctx.GetPlace()), +// d_bias->template mutable_data>( +// ctx.GetPlace()), +// epsilon, saved_mean_data, saved_var_data)); +#else + // CUDNN PER_ACTIVATION mode only support small batch size + const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; + const bool use_native_kernel = + 
(x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD); + if (use_native_kernel) { if (compute_format == DataLayout::kNCHW) { BNBackward <<>>( @@ -637,22 +627,67 @@ void BatchNormGradRawKernel(const Context &ctx, ctx.template Alloc>(d_scale), ctx.template Alloc>(d_bias)); } + } else { +#if CUDNN_VERSION_MIN(7, 4, 1) + size_t workspace_size = 0; + void *workspace_ptr = nullptr; + DenseTensor workspace_tensor; + auto reserve_space_size = reserve_space->memory_size(); + // --------------- cudnn batchnorm workspace --------------- + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload:: + cudnnGetBatchNormalizationBackwardExWorkspaceSize( + /*handle=*/ctx.cudnn_handle(), + /*mode=*/mode_, + /*bnIps=*/CUDNN_BATCHNORM_OPS_BN, + /*xDesc=*/data_desc_, + /*yDesc=*/data_desc_, + /*dyDesc=*/data_desc_, + /*dzDesc=*/nullptr, + /*dxDesc=*/data_desc_, + /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, + /*activationDesc=*/nullptr, + /*sizeInBytes=*/&workspace_size)); -// TODO(wangran16): wait for MIOpen to improve the performance of BN -// PADDLE_ENFORCE_GPU_SUCCESS( -// platform::dynload::miopenBatchNormalizationBackward( -// dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), -// CudnnDataType::kZero(), CudnnDataType::kOne(), -// CudnnDataType::kZero(), data_desc_, -// transformed_x.template data(), data_desc_, -// transformed_d_y.template data(), data_desc_, -// transformed_d_x.template mutable_data(ctx.GetPlace()), -// bn_param_desc_, scale->template data>(), -// d_scale->template mutable_data>( -// ctx.GetPlace()), -// d_bias->template mutable_data>( -// ctx.GetPlace()), -// epsilon, saved_mean_data, saved_var_data)); + workspace_tensor.Resize({static_cast(workspace_size)}); + workspace_ptr = + static_cast(ctx.template Alloc(&workspace_tensor)); + + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload::cudnnBatchNormalizationBackwardEx( + /*handle=*/ctx.cudnn_handle(), + /*mode=*/mode_, + /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, + /*alphaDataDiff=*/CudnnDataType::kOne(), + /*betaDataDiff=*/CudnnDataType::kZero(), + /*alphaParamDiff=*/CudnnDataType::kOne(), + /*betaParamDiff=*/CudnnDataType::kZero(), + /*xDesc=*/data_desc_, + /*xData=*/transformed_x.template data(), + /*yDesc=*/nullptr, + /*yData=*/nullptr, + /*dyDesc=*/data_desc_, + /*dyData=*/transformed_d_y.template data(), + /*dzDesc=*/nullptr, + /*dzData=*/nullptr, + /*dxDesc=*/data_desc_, + /*dxData=*/ctx.template Alloc(&transformed_d_x), + /*dBnScaleBiasDesc=*/bn_param_desc_, + /*bnScaleData=*/scale.template data>(), + /*bnBiasData=*/nullptr, + /*dBnScaleData=*/ + ctx.template Alloc>(d_scale), + /*dBnBiasData=*/ + ctx.template Alloc>(d_bias), + /*epsilon=*/epsilon, + /*savedMean=*/saved_mean_data, + /*savedInvVariance=*/saved_var_data, + /*activationDesc=*/nullptr, + /*workspace=*/workspace_ptr, + /*workSpaceSizeInBytes=*/workspace_size, + /*reserveSpace=*/ + const_cast(reserve_space->template data()), + /*reserveSpaceSizeInBytes=*/reserve_space_size)); #else PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnBatchNormalizationBackward( @@ -675,8 +710,9 @@ void BatchNormGradRawKernel(const Context &ctx, epsilon, saved_mean_data, saved_var_data)); -#endif +#endif // CUDNN_VERSION_MIN(7, 4, 1) } +#endif if (data_layout == DataLayout::kNHWC && compute_format == DataLayout::kNCHW) { diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 73af404910e0b..702722591553f 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -557,22 
+557,6 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } -#if CUDNN_VERSION_MIN(7, 4, 1) - // -------------- allocate reserve space for backward-------------- - if (reserve_space != nullptr) { - size_t reserve_space_size = 0; - PADDLE_ENFORCE_GPU_SUCCESS( - paddle::platform::dynload:: - cudnnGetBatchNormalizationTrainingExReserveSpaceSize( - /*handle=*/handle, - /*mode=*/mode_, - /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, - /*activationDesc=*/nullptr, - /*xDesc=*/data_desc_, - /*sizeInBytes=*/&reserve_space_size)); - reserve_space->Resize({static_cast(reserve_space_size)}); - } -#endif } else { #if CUDNN_VERSION_MIN(7, 4, 1) size_t workspace_size = 0; From 3c26e1ec167611efd9f235d1ba8cfe9272df3e6d Mon Sep 17 00:00:00 2001 From: Zihang Yao <1162526220@qq.com> Date: Sat, 11 Jun 2022 14:39:53 +0800 Subject: [PATCH 13/13] add unit test for batchnorm1d --- .../tests/unittests/test_batch_norm_op_v2.py | 42 ++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py index 9db95f094a7e3..cfd5d5f7c9bd0 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -110,11 +110,43 @@ def compute_v2(x): y.backward() return y.numpy(), x1.gradient() - x = np.random.randn(*shape).astype("float32") - y1, g1 = compute_v1(x) - y2, g2 = compute_v2(x) - self.assertTrue(np.allclose(g1, g2)) - self.assertTrue(np.allclose(y1, y2)) + x = np.random.randn(*shape).astype("float32") + y1, g1 = compute_v1(x) + y2, g2 = compute_v2(x) + self.assertTrue(np.allclose(g1, g2)) + self.assertTrue(np.allclose(y1, y2)) + + def test_eager_api_1d(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + shape = [200000, 4] + + def compute_v1(x): + with fluid.dygraph.guard(p): + bn = fluid.dygraph.BatchNorm(shape[1]) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_v2(x): + with fluid.dygraph.guard(p): + with _test_eager_guard(): + bn = paddle.nn.BatchNorm1D(shape[1]) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + x = np.random.randn(*shape).astype("float32") + y1, g1 = compute_v1(x) + y2, g2 = compute_v2(x) + self.assertTrue(np.allclose(g1, g2)) + self.assertTrue(np.allclose(y1, y2)) def test_dygraph(self): places = [fluid.CPUPlace()]
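
For readers following the diffs above: the series routes 2-D (BatchNorm1D) inputs with very large batch sizes to the hand-written CUDA kernels, because cuDNN's PER_ACTIVATION mode only supports comparatively small batch sizes. Below is a minimal standalone sketch of that dispatch decision; the helper name and signature are illustrative only and do not appear in the patch, where the check is written inline in BatchNormKernel and BatchNormGradRawKernel.

#include <cstddef>

// Batches at or above this size cannot go through cuDNN's PER_ACTIVATION
// mode (the mode used for rank-2 [N, C] inputs), so they fall back to the
// native BNForwardTraining / BNBackward kernels added in this series.
constexpr std::size_t kCudnnPerActivationThreshold = 131070;

// Hypothetical helper, not part of the patch: returns true when the native
// kernels should be used instead of the cuDNN calls (e.g.
// cudnnBatchNormalizationBackwardEx in the backward pass).
bool UseNativeBatchNormKernel(int x_rank, std::size_t batch_size) {
  return x_rank == 2 && batch_size >= kCudnnPerActivationThreshold;
}

The same condition guards both the forward training path and the backward path, so forward and backward always agree on which implementation produced the saved mean and inverse variance.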