diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 702722591553f..73af404910e0b 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -557,6 +557,22 @@ void BatchNormKernel(const Context &ctx, saved_mean->template data>(), saved_variance->template data>()); } +#if CUDNN_VERSION_MIN(7, 4, 1) + // -------------- allocate reserve space for backward-------------- + if (reserve_space != nullptr) { + size_t reserve_space_size = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + paddle::platform::dynload:: + cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, + /*activationDesc=*/nullptr, + /*xDesc=*/data_desc_, + /*sizeInBytes=*/&reserve_space_size)); + reserve_space->Resize({static_cast(reserve_space_size)}); + } +#endif } else { #if CUDNN_VERSION_MIN(7, 4, 1) size_t workspace_size = 0;