forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 1
/
batch_norm_grad_kernel.cu
1043 lines (982 loc) · 40.6 KB
/
batch_norm_grad_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/gpu/batch_norm_utils.h"
// LAUNCH_BOUNDS(BlockDim) annotates a kernel with __launch_bounds__(BlockDim)
// when __HIPCC__ is defined; in every other build it expands to nothing.
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif
// Declares the flag that this file reads as
// FLAGS_cudnn_batchnorm_spatial_persistent. Presumably the flag is defined in
// another translation unit -- TODO confirm.
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
namespace phi {
// Shorthand for paddle::platform::CudnnDataType<T> inside namespace phi.
template <typename T>
using CudnnDataType = paddle::platform::CudnnDataType<T>;
// Element type of the per-channel parameter tensors (scale, bias, mean,
// variance and their gradients) and of the accumulators used by the kernels in
// this file. NOTE(review): presumably wider than T for half precision --
// confirm against CudnnDataType.
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
// Computes the scale and bias gradients, one channel per block iteration.
//
// A block handles channels blockIdx.x, blockIdx.x + gridDim.x, ... and its
// threads stride over the N * HxW elements of the current channel:
//   dbias[c]  = sum(dy)
//   dscale[c] = sum(dy * (x - mean[c])) * (1 / sqrt(variance[c] + epsilon))
// `layout` selects how (element, channel) maps to a flat offset in dy / x.
// The block reduction is sized by BlockDim; the launches in this file use
// blockDim.x == BlockDim.
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void KeBNBackwardScaleBias(
    const T *dy,
    const T *x,
    const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *variance,
    const double epsilon,
    const int N,
    const int C,
    const int HxW,
    BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  using ParamT = BatchNormParamType<T>;
  using Reducer = cub::BlockReduce<ParamT, BlockDim>;

  // One scratch area per reduction so the two Reduce calls do not alias.
  __shared__ typename Reducer::TempStorage scale_scratch;
  __shared__ typename Reducer::TempStorage bias_scratch;

  const int channels = C;
  const int per_channel = N * HxW;

  for (int ch = blockIdx.x; ch < channels; ch += gridDim.x) {
    ParamT scale_acc = static_cast<ParamT>(0);
    ParamT bias_acc = static_cast<ParamT>(0);
    // epsilon is a double, so the sqrt and the division run in double before
    // the result is narrowed to ParamT.
    ParamT inv_std = 1.0 / sqrt(variance[ch] + epsilon);
    ParamT channel_mean = mean[ch];

    for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
      int offset;
      if (layout == phi::DataLayout::kNCHW) {
        offset = (k / HxW * C + ch) * HxW + k % HxW;
      } else {
        offset = k * channels + ch;
      }
      scale_acc += static_cast<ParamT>(dy[offset]) *
                   (static_cast<ParamT>(x[offset]) - channel_mean);
      bias_acc += static_cast<ParamT>(dy[offset]);
    }

    scale_acc = Reducer(scale_scratch).Reduce(scale_acc, cub::Sum());
    bias_acc = Reducer(bias_scratch).Reduce(bias_acc, cub::Sum());

    // Only thread 0 is guaranteed to hold the block-wide totals.
    if (threadIdx.x == 0) {
      dscale[ch] = scale_acc * inv_std;
      dbias[ch] = bias_acc;
    }
    // The scratch areas are reused by the next channel.
    __syncthreads();
  }
}
// Element-wise input gradient when fixed (non-batch) statistics are used:
//   dx[i] = dy[i] * scale[c] * (1 / sqrt(variance[c] + epsilon))
// where c is the channel of flat element i under `layout`. Threads walk the
// `num` elements with a stride of blockDim.x * gridDim.x.
// For kNCHW, HxW must be the number of elements per (sample, channel) plane;
// it is not read for the other layout.
template <typename T, phi::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
                                        const BatchNormParamType<T> *scale,
                                        const BatchNormParamType<T> *variance,
                                        const double epsilon,
                                        const int C,
                                        const int HxW,
                                        const int num,
                                        T *dx) {
  using ParamT = BatchNormParamType<T>;

  int first = blockIdx.x * blockDim.x + threadIdx.x;
  int step = blockDim.x * gridDim.x;

  for (int pos = first; pos < num; pos += step) {
    int channel;
    if (layout == phi::DataLayout::kNCHW) {
      channel = pos / HxW % C;
    } else {
      channel = pos % C;
    }
    // sqrt and division run in double (epsilon is double), then narrow.
    ParamT inv_std = 1.0 / sqrt(variance[channel] + epsilon);
    dx[pos] = static_cast<T>(static_cast<ParamT>(dy[pos]) * scale[channel] *
                             inv_std);
  }
}
// Recovers the forward input from the forward output, element by element:
//   x[i] = (y[i] - bias[c]) / scale[c] / variance[c] + mean[c]
// `variance[c]` is used as a plain divisor, so the formula inverts the forward
// pass only if that buffer holds the inverse standard deviation.
// `epsilon` is accepted for signature compatibility but is not read.
// `M` is the per-(sample, channel) element count used to find the channel for
// kNCHW; other layouts take the channel as i % C.
template <typename T>
static __global__ void KeBNRestoreData(const phi::DataLayout layout,
                                       T *x,
                                       const BatchNormParamType<T> *scale,
                                       const BatchNormParamType<T> *bias,
                                       const BatchNormParamType<T> *mean,
                                       const BatchNormParamType<T> *variance,
                                       double epsilon,
                                       int C,
                                       int M,
                                       const int num,
                                       const T *y) {
  using ParamT = BatchNormParamType<T>;

  int first = blockIdx.x * blockDim.x + threadIdx.x;
  int step = blockDim.x * gridDim.x;

  for (int pos = first; pos < num; pos += step) {
    int channel;
    if (layout == phi::DataLayout::kNCHW) {
      channel = (pos / M) % C;
    } else {
      channel = pos % C;
    }
    ParamT out_val = static_cast<ParamT>(y[pos]);
    ParamT in_val =
        (out_val - bias[channel]) / scale[channel] / variance[channel] +
        mean[channel];
    x[pos] = static_cast<T>(in_val);
  }
}
// Host-side functor for in-place batch norm: checks that the input and output
// buffers are the same memory, then launches KeBNRestoreData on `stream` to
// rebuild the forward input inside that buffer.
template <typename T>
class InplaceHelper {
 public:
  // `grid2` / `block` are the launch dimensions; every other argument is
  // forwarded unchanged to KeBNRestoreData.
  void operator()(const phi::DataLayout layout,
                  T *x,
                  const BatchNormParamType<T> *scale,
                  const BatchNormParamType<T> *bias,
                  const BatchNormParamType<T> *mean,
                  const BatchNormParamType<T> *variance,
                  double epsilon,
                  int C,
                  int M,
                  const int num,
                  const T *y,
                  int grid2,
                  const int block,
                  const gpuStream_t &stream) {
    // In-place mode is only meaningful when x and y alias each other.
    PADDLE_ENFORCE_EQ(x,
                      y,
                      phi::errors::InvalidArgument(
                          "X and Y should be inplaced in inplace mode"));
    KeBNRestoreData<T><<<grid2, block, 0, stream>>>(
        layout, x, scale, bias, mean, variance, epsilon, C, M, num, y);
  }
};
// Fused batch-norm backward pass: writes dx, dscale and dbias.
//
// A block handles channels blockIdx.x, blockIdx.x + gridDim.x, ... and its
// threads stride over the N * HxW elements of the current channel. Per
// channel the kernel
//   1. takes mean / inverse std from saved_mean / saved_inv_variance when both
//      pointers are non-null, otherwise recomputes them from x
//      (inv_std = 1 / sqrt(E[x^2] - mean^2 + epsilon));
//   2. reduces dbias = sum(dy) and dscale = sum(dy * (x - mean)) * inv_std;
//   3. writes dx = scale * inv_std *
//        (dy - dbias / m - (x - mean) * inv_std * dscale / m), m = N * HxW.
// The block reductions are sized by BlockDim; the launches in this file use
// blockDim.x == BlockDim.
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
    const T *dy,
    const T *x,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *saved_mean,
    const BatchNormParamType<T> *saved_inv_variance,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    T *dx,
    BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  const int outer_size = C;
  const int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ds_storage;
  __shared__ typename BlockReduce::TempStorage db_storage;
  __shared__ typename BlockReduce::TempStorage mean_storage;
  __shared__ typename BlockReduce::TempStorage variance_storage;
  // Per-channel values published by thread 0 to the rest of the block.
  __shared__ BatchNormParamType<T> inv_var_val;
  __shared__ BatchNormParamType<T> mean_val;
  __shared__ BatchNormParamType<T> dscale_val;
  __shared__ BatchNormParamType<T> dbias_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);

    if (saved_mean && saved_inv_variance) {
      if (threadIdx.x == 0) {
        inv_var_val = saved_inv_variance[i];
        mean_val = saved_mean[i];
      }
    } else {
      // No saved statistics: recompute mean and inverse std from x.
      BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
      BatchNormParamType<T> x_square_sum =
          static_cast<BatchNormParamType<T>>(0);

      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = layout == phi::DataLayout::kNCHW
                              ? (j / HxW * C + i) * HxW + j % HxW
                              : j * outer_size + i;
        BatchNormParamType<T> x_i =
            static_cast<BatchNormParamType<T>>(x[index]);
        x_sum += x_i;
        x_square_sum += x_i * x_i;
      }
      x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
      x_square_sum =
          BlockReduce(variance_storage).Reduce(x_square_sum, cub::Sum());
      if (threadIdx.x == 0) {
        mean_val = x_sum / inner_size;
        inv_var_val =
            1 / sqrt(x_square_sum / inner_size - mean_val * mean_val + epsilon);
      }
    }
    // Make mean_val / inv_var_val visible to every thread.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == phi::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> dy_i =
          static_cast<BatchNormParamType<T>>(dy[index]);
      ds_sum +=
          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
      db_sum += dy_i;
    }
    ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
    db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
    if (threadIdx.x == 0) {
      dscale_val = ds_sum * inv_var_val;
      dbias_val = db_sum;
      dscale[i] = dscale_val;
      dbias[i] = dbias_val;
    }
    // Make dscale_val / dbias_val visible to every thread.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == phi::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      dx[index] = scale[i] * inv_var_val *
                  (static_cast<BatchNormParamType<T>>(dy[index]) -
                   dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
                   (static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
                       inv_var_val * dscale_val / inner_size);
    }

    // A block may process several channels (gridDim.x can be smaller than C).
    // Without this barrier thread 0 could start the next channel and overwrite
    // mean_val / inv_var_val / dscale_val / dbias_val while slower threads are
    // still reading them in the dx loop above. The loop bounds depend only on
    // blockIdx / gridDim, so every thread of the block reaches this barrier.
    __syncthreads();
  }
}
// Batch-norm backward pass for the input gradient only.
//
// A block handles channels blockIdx.x, blockIdx.x + gridDim.x, ... and its
// threads stride over the N * HxW elements of the current channel. Per
// channel it reduces sum(dy) and sum(dy * (x - mean)) and then writes
//   dx = (dy - sum(dy) / m
//            - (x - mean) * sum(dy * (x - mean)) * inv_std^2 / m)
//        * scale * inv_std,                                  m = N * HxW.
// Despite its name, `variance[c]` is used directly as inv_std (it is only
// ever multiplied), so the caller must pass the inverse standard deviation.
// The block reductions are sized by BlockDim; the launches in this file use
// blockDim.x == BlockDim.
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
    const T *dy,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *mean,
    const T *x,
    const BatchNormParamType<T> *variance,
    const int C,
    const int N,
    const int HxW,
    T *dx) {
  const int outer_size = C;
  const int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  // Block-wide totals published by thread 0 to the rest of the block.
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> inv_var_i = variance[i];
    BatchNormParamType<T> mean_i = mean[i];
    BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> dy_x_sub_mean_sum =
        static_cast<BatchNormParamType<T>>(0);

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == phi::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> dy_i =
          static_cast<BatchNormParamType<T>>(dy[index]);
      dy_sum += dy_i;
      dy_x_sub_mean_sum +=
          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
    }

    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
    dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage)
                            .Reduce(dy_x_sub_mean_sum, cub::Sum());
    if (threadIdx.x == 0) {
      dy_sum_val = dy_sum;
      dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
    }
    // Make the totals visible to every thread.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == phi::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      dx[index] =
          (static_cast<BatchNormParamType<T>>(dy[index]) -
           dy_sum_val / static_cast<BatchNormParamType<T>>(inner_size) -
           (static_cast<BatchNormParamType<T>>(x[index]) - mean_i) *
               dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) *
          scale[i] * inv_var_i;
    }

    // A block may process several channels. Keep thread 0 from overwriting
    // dy_sum_val / dy_x_sub_mean_sum_val for the next channel while other
    // threads are still reading them above, without relying on whatever
    // barriers BlockReduce::Reduce may or may not contain. All threads of the
    // block run the same number of iterations, so the barrier is uniform.
    __syncthreads();
  }
}
// GPU batch-norm gradient.
//
// Computes x_grad, scale_grad and bias_grad (each may be null; scale_grad and
// bias_grad are only produced as a pair) from the forward input `x`, the
// upstream gradient `y_grad` and the per-channel parameters.
//
//   * use_global_stats || is_test: gradients are taken with respect to the
//     fixed statistics in `mean` / `variance`, using the hand-written kernels.
//   * otherwise (training): the batch statistics in `saved_mean` /
//     `saved_variance` are used. When all three gradients are requested the
//     work goes to cuDNN (or to BNBackward on HIP / for very large 2-D
//     batches); when only some are requested, BNBackwardData and
//     KeBNBackwardScaleBias are launched individually.
//   * is_inplace: `x` is first rebuilt from the forward output via
//     InplaceHelper.
//
// `momentum`, `trainable_statistics` and `fuse_with_relu` are part of the
// kernel signature but are not read here. `reserve_space` is only read on the
// cuDNN >= 7.4.1 path.
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context &ctx,
                            const DenseTensor &x,
                            const DenseTensor &scale,
                            const DenseTensor &bias,
                            const paddle::optional<DenseTensor> &mean,
                            const paddle::optional<DenseTensor> &variance,
                            const DenseTensor &saved_mean,
                            const DenseTensor &saved_variance,
                            const paddle::optional<DenseTensor> &reserve_space,
                            const DenseTensor &y_grad,
                            float momentum,
                            float epsilon_f,
                            const std::string &data_layout_str,
                            bool is_test,
                            bool use_global_stats,
                            bool trainable_statistics,
                            bool fuse_with_relu,
                            bool is_inplace,
                            DenseTensor *x_grad,
                            DenseTensor *scale_grad,
                            DenseTensor *bias_grad) {
  double epsilon = static_cast<double>(epsilon_f);
  const DataLayout data_layout =
      paddle::framework::StringToDataLayout(data_layout_str);
  const auto *d_y = &y_grad;
  auto *d_x = x_grad;
  auto *d_scale = scale_grad;
  auto *d_bias = bias_grad;

  // Test mode always uses the global statistics.
  use_global_stats = is_test || use_global_stats;

  const auto &x_dims = x.dims();

  PADDLE_ENFORCE_EQ(
      x_dims.size() >= 2 && x_dims.size() <= 5,
      true,
      phi::errors::InvalidArgument(
          "The size of input's dimensions should be between 2 and 5."
          "But received: the size of input's dimensions is [%d],"
          "the dimensions of input is [%s]",
          x_dims.size(),
          x_dims));
  int N, C, H, W, D;
  phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);

  // init output
  if (d_x) {
    ctx.template Alloc<T>(d_x);
  }

  if (d_scale && d_bias) {
    ctx.template Alloc<BatchNormParamType<T>>(d_scale);
    ctx.template Alloc<BatchNormParamType<T>>(d_bias);
  }

  PADDLE_ENFORCE_EQ(
      scale.dims().size(),
      1UL,
      phi::errors::InvalidArgument(
          "The size of scale's dimensions must equal to 1. But received: "
          "the size of scale's dimensions is [%d], the dimensions of scale "
          "is [%s].",
          scale.dims().size(),
          scale.dims()));
  PADDLE_ENFORCE_EQ(
      scale.dims()[0],
      C,
      phi::errors::InvalidArgument(
          "The first dimension of scale must equal to Channels[%d]. But "
          "received: the first dimension of scale is [%d]",
          C,
          scale.dims()[0]));

  auto dtype = paddle::platform::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP
  auto compute_format =
      data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP does not support the NHWC compute format
// auto compute_format = DataLayout::kNCHW;
#else
  // NHWC is only computed natively for half precision with the persistent
  // cuDNN mode enabled and a reserve space available; otherwise the data is
  // transposed to NCHW below.
  const bool fast_nhwc_batch_norm = dtype == CUDNN_DATA_HALF &&
                                    FLAGS_cudnn_batchnorm_spatial_persistent &&
                                    (reserve_space.get_ptr() != nullptr);
  auto compute_format = fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
                            ? DataLayout::kNHWC
                            : DataLayout::kNCHW;
#endif

  // x, y_grad and x_grad themselves (as opposed to the transformed_* copies
  // below) are stored in data_layout. Kernels launched directly on those raw
  // buffers must therefore pick their indexing from data_layout, not from
  // compute_format, which can be NCHW while the raw data is NHWC.
  const bool raw_is_nhwc = data_layout == DataLayout::kNHWC;

  DenseTensor transformed_x(x.type());
  DenseTensor transformed_d_y(d_y->type());
  DenseTensor transformed_d_x;
  if (data_layout == DataLayout::kNHWC && compute_format == DataLayout::kNCHW &&
      x_dims.size() > 2) {
    VLOG(3) << "Transform input tensor from NHWC to NCHW.";
    ResizeToChannelFirst<Context, T>(ctx, &x, &transformed_x);
    TransToChannelFirst<Context, T>(ctx, &x, &transformed_x);
    ResizeToChannelFirst<Context, T>(ctx, d_y, &transformed_d_y);
    TransToChannelFirst<Context, T>(ctx, d_y, &transformed_d_y);
    if (d_x) {
      ResizeToChannelFirst<Context, T>(ctx, d_x, &transformed_d_x);
    }
  } else {
    transformed_x.ShareDataWith(x);
    transformed_d_y.ShareDataWith(*d_y);
    if (d_x) {
      transformed_d_x.ShareDataWith(*d_x);
    }
  }

  // Logical dims are always {N, C, H, W, D}; the strides express the memory
  // order of compute_format.
  std::vector<int> dims;
  std::vector<int> strides;
  if (compute_format == DataLayout::kNCHW) {
    dims = {N, C, H, W, D};
    strides = {C * H * W * D, H * W * D, W * D, D, 1};
  } else {
    dims = {N, C, H, W, D};
    strides = {H * W * C * D, 1, W * D * C, D * C, C};
  }

  const int num = transformed_x.numel();
// NOTE(review): the macro tested at the top of this file is __HIPCC__; confirm
// whether HIPCC (no underscores) is intentional here.
#ifdef HIPCC
  const int block = 256;
#else
  const int block = 512;
#endif
  int max_threads = ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  // grid1: one thread per element; grid2: one block per channel, capped.
  int grid1 = (num + block - 1) / block;
  int grid2 = std::min(C, max_blocks);
  auto stream = ctx.stream();
  InplaceHelper<T> inplace_functor;

  if (!use_global_stats) {
    if ((N * H * W * D) == 1) {
      // Single element per channel: pass the gradient through and zero the
      // parameter gradients.
      if (d_x) {
        paddle::framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
      }
      phi::funcs::SetConstant<Context, BatchNormParamType<T>> functor;
      // The parameter gradients are optional; do not touch null outputs.
      if (d_scale) {
        functor(ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
      }
      if (d_bias) {
        functor(ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
      }
      return;
    }

// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;

// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t bn_param_desc_;
    cudnnBatchNormMode_t mode_;

    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::cudnnCreateTensorDescriptor(
            &bn_param_desc_));
#endif
    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
    if (FLAGS_cudnn_batchnorm_spatial_persistent) {
      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
    } else if (H == 1 && W == 1) {
      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
    } else {
      mode_ = CUDNN_BATCHNORM_SPATIAL;
    }
#else
    if (H == 1 && W == 1) {
      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
    } else {
      mode_ = CUDNN_BATCHNORM_SPATIAL;
    }
#endif  // CUDNN_VERSION_MIN(7, 0, 1)

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
//     data_desc_, CudnnDataType<T>::type,
//     x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
//     const_cast<int *>(strides.data())));
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
//                                                       data_desc_, mode_));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::cudnnSetTensorNdDescriptor(
            data_desc_,
            CudnnDataType<T>::type,
            x_dims.size() > 3 ? x_dims.size() : 4,
            dims.data(),
            strides.data()));
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::cudnnDeriveBNTensorDescriptor(
            bn_param_desc_, data_desc_, mode_));
#endif

    const auto *saved_mean_data =
        saved_mean.template data<BatchNormParamType<T>>();
    const auto *saved_var_data =
        saved_variance.template data<BatchNormParamType<T>>();

    if (is_inplace) {
      // Rebuild the forward input inside transformed_x's buffer.
      inplace_functor(compute_format,
                      transformed_x.data<T>(),
                      scale.template data<BatchNormParamType<T>>(),
                      bias.template data<BatchNormParamType<T>>(),
                      saved_mean_data,
                      saved_var_data,
                      epsilon,
                      C,
                      H * W * D,
                      num,
                      transformed_x.data<T>(),
                      grid2,
                      block,
                      stream);
    }

    // This branch calls CUDNN APIs
    if (d_x && d_scale && d_bias) {
#ifdef PADDLE_WITH_HIP
      if (compute_format == DataLayout::kNCHW) {
        BNBackward<T, block, DataLayout::kNCHW>
            <<<grid2, block, 0, ctx.stream()>>>(
                transformed_d_y.template data<T>(),
                transformed_x.template data<T>(),
                scale.template data<BatchNormParamType<T>>(),
                saved_mean_data,
                saved_var_data,
                C,
                N,
                H * W * D,
                epsilon,
                transformed_d_x.template data<T>(),
                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                ctx.template Alloc<BatchNormParamType<T>>(d_bias));
      } else {
        BNBackward<T, block, DataLayout::kNHWC>
            <<<grid2, block, 0, ctx.stream()>>>(
                transformed_d_y.template data<T>(),
                transformed_x.template data<T>(),
                scale.template data<BatchNormParamType<T>>(),
                saved_mean_data,
                saved_var_data,
                C,
                N,
                H * W * D,
                epsilon,
                transformed_d_x.template data<T>(),
                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                ctx.template Alloc<BatchNormParamType<T>>(d_bias));
      }

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenBatchNormalizationBackward(
//         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
//         CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
//         CudnnDataType<T>::kZero(), data_desc_,
//         transformed_x.template data<T>(), data_desc_,
//         transformed_d_y.template data<T>(), data_desc_,
//         transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
//         bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
//         d_scale->template mutable_data<BatchNormParamType<T>>(
//             ctx.GetPlace()),
//         d_bias->template mutable_data<BatchNormParamType<T>>(
//             ctx.GetPlace()),
//         epsilon, saved_mean_data, saved_var_data));
#else
      // CUDNN PER_ACTIVATION mode only support small batch size
      const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
      const bool use_native_kernel =
          (x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD);
      if (use_native_kernel) {
        if (compute_format == DataLayout::kNCHW) {
          BNBackward<T, block, DataLayout::kNCHW>
              <<<grid2, block, 0, ctx.stream()>>>(
                  transformed_d_y.template data<T>(),
                  transformed_x.template data<T>(),
                  scale.template data<BatchNormParamType<T>>(),
                  saved_mean_data,
                  saved_var_data,
                  C,
                  N,
                  H * W * D,
                  epsilon,
                  transformed_d_x.template data<T>(),
                  ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                  ctx.template Alloc<BatchNormParamType<T>>(d_bias));
        } else {
          BNBackward<T, block, DataLayout::kNHWC>
              <<<grid2, block, 0, ctx.stream()>>>(
                  transformed_d_y.template data<T>(),
                  transformed_x.template data<T>(),
                  scale.template data<BatchNormParamType<T>>(),
                  saved_mean_data,
                  saved_var_data,
                  C,
                  N,
                  H * W * D,
                  epsilon,
                  transformed_d_x.template data<T>(),
                  ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                  ctx.template Alloc<BatchNormParamType<T>>(d_bias));
        }
      } else {
#if CUDNN_VERSION_MIN(7, 4, 1)
        size_t workspace_size = 0;
        void *workspace_ptr = nullptr;
        DenseTensor workspace_tensor;
        // NOTE(review): reserve_space is dereferenced unconditionally here;
        // confirm callers always supply it on this path.
        auto reserve_space_size = reserve_space->memory_size();
        // --------------- cudnn batchnorm workspace ---------------
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::
                cudnnGetBatchNormalizationBackwardExWorkspaceSize(
                    /*handle=*/ctx.cudnn_handle(),
                    /*mode=*/mode_,
                    /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
                    /*xDesc=*/data_desc_,
                    /*yDesc=*/data_desc_,
                    /*dyDesc=*/data_desc_,
                    /*dzDesc=*/nullptr,
                    /*dxDesc=*/data_desc_,
                    /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                    /*activationDesc=*/nullptr,
                    /*sizeInBytes=*/&workspace_size));

        workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
        workspace_ptr =
            static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));

        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
                /*handle=*/ctx.cudnn_handle(),
                /*mode=*/mode_,
                /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                /*alphaDataDiff=*/CudnnDataType<T>::kOne(),
                /*betaDataDiff=*/CudnnDataType<T>::kZero(),
                /*alphaParamDiff=*/CudnnDataType<T>::kOne(),
                /*betaParamDiff=*/CudnnDataType<T>::kZero(),
                /*xDesc=*/data_desc_,
                /*xData=*/transformed_x.template data<T>(),
                /*yDesc=*/nullptr,
                /*yData=*/nullptr,
                /*dyDesc=*/data_desc_,
                /*dyData=*/transformed_d_y.template data<T>(),
                /*dzDesc=*/nullptr,
                /*dzData=*/nullptr,
                /*dxDesc=*/data_desc_,
                /*dxData=*/ctx.template Alloc<T>(&transformed_d_x),
                /*dBnScaleBiasDesc=*/bn_param_desc_,
                /*bnScaleData=*/scale.template data<BatchNormParamType<T>>(),
                /*bnBiasData=*/nullptr,
                /*dBnScaleData=*/
                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                /*dBnBiasData=*/
                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
                /*epsilon=*/epsilon,
                /*savedMean=*/saved_mean_data,
                /*savedInvVariance=*/saved_var_data,
                /*activationDesc=*/nullptr,
                /*workspace=*/workspace_ptr,
                /*workSpaceSizeInBytes=*/workspace_size,
                /*reserveSpace=*/
                const_cast<uint8_t *>(reserve_space->template data<uint8_t>()),
                /*reserveSpaceSizeInBytes=*/reserve_space_size));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::cudnnBatchNormalizationBackward(
                ctx.cudnn_handle(),
                mode_,
                CudnnDataType<T>::kOne(),
                CudnnDataType<T>::kZero(),
                CudnnDataType<T>::kOne(),
                CudnnDataType<T>::kZero(),
                data_desc_,
                transformed_x.template data<T>(),
                data_desc_,
                transformed_d_y.template data<T>(),
                data_desc_,
                ctx.template Alloc<T>(&transformed_d_x),
                bn_param_desc_,
                scale.template data<BatchNormParamType<T>>(),
                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
                epsilon,
                saved_mean_data,
                saved_var_data));
#endif  // CUDNN_VERSION_MIN(7, 4, 1)
      }
#endif

      if (data_layout == DataLayout::kNHWC &&
          compute_format == DataLayout::kNCHW) {
        VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
        TransToChannelLast<Context, T>(ctx, &transformed_d_x, d_x);
      }
    } else {
      // This branch call CUDA kernels
      // These launches read x / d_y and write d_x directly, so the indexing
      // follows the layout of the raw buffers (see raw_is_nhwc).
      if (!raw_is_nhwc) {
        if (d_x) {
          BNBackwardData<T, block, phi::DataLayout::kNCHW>
              <<<grid2, block, 0, ctx.stream()>>>(
                  d_y->data<T>(),
                  scale.data<BatchNormParamType<T>>(),
                  saved_mean_data,
                  x.data<T>(),
                  saved_var_data,
                  C,
                  N,
                  H * W * D,
                  d_x->data<T>());
        }
        if (d_scale && d_bias) {
          KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
              <<<grid2, block, 0, stream>>>(
                  d_y->data<T>(),
                  x.data<T>(),
                  saved_mean_data,
                  saved_var_data,
                  epsilon,
                  N,
                  C,
                  H * W * D,
                  d_scale->data<BatchNormParamType<T>>(),
                  d_bias->data<BatchNormParamType<T>>());
        }
      } else {
        if (d_x) {
          BNBackwardData<T, block, phi::DataLayout::kNHWC>
              <<<grid2, block, 0, ctx.stream()>>>(
                  d_y->data<T>(),
                  scale.data<BatchNormParamType<T>>(),
                  saved_mean_data,
                  x.data<T>(),
                  saved_var_data,
                  C,
                  N,
                  H * W * D,
                  d_x->data<T>());
        }
        if (d_scale && d_bias) {
          KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
              <<<grid2, block, 0, stream>>>(
                  d_y->data<T>(),
                  x.data<T>(),
                  saved_mean_data,
                  saved_var_data,
                  epsilon,
                  N,
                  C,
                  H * W * D,
                  d_scale->data<BatchNormParamType<T>>(),
                  d_bias->data<BatchNormParamType<T>>());
        }
      }
    }

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit.
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
    // clean when exit.
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    PADDLE_ENFORCE_GPU_SUCCESS(
        paddle::platform::dynload::cudnnDestroyTensorDescriptor(
            bn_param_desc_));
#endif
  } else {
    // Global-statistics path: gradients with respect to fixed mean/variance.
    const auto *running_mean = mean.get_ptr();
    const auto *running_var = variance.get_ptr();

    const auto *running_mean_data =
        running_mean->template data<BatchNormParamType<T>>();
    const auto *running_var_data =
        running_var->template data<BatchNormParamType<T>>();

    if (is_inplace) {
      auto px = x;
      inplace_functor(data_layout,
                      ctx.template Alloc<T>(&px),
                      scale.template data<BatchNormParamType<T>>(),
                      bias.template data<BatchNormParamType<T>>(),
                      running_mean_data,
                      running_var_data,
                      epsilon,
                      C,
                      H * W * D,
                      num,
                      x.data<T>(),
                      grid2,
                      block,
                      stream);
    }

    // As above, these kernels work on the raw buffers, so the layout template
    // follows data_layout. KeBNBackwardData needs the full per-channel extent
    // H * W * D to find the channel of a flat NCHW index when D > 1.
    if (!raw_is_nhwc) {
      if (d_x) {
        KeBNBackwardData<T, phi::DataLayout::kNCHW>
            <<<grid1, block, 0, stream>>>(d_y->data<T>(),
                                          scale.data<BatchNormParamType<T>>(),
                                          running_var_data,
                                          epsilon,
                                          C,
                                          H * W * D,
                                          num,
                                          d_x->data<T>());
      }
      if (d_scale && d_bias) {
        KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
            <<<grid2, block, 0, stream>>>(
                d_y->data<T>(),
                x.data<T>(),
                running_mean_data,
                running_var_data,
                epsilon,
                N,
                C,
                H * W * D,
                d_scale->data<BatchNormParamType<T>>(),
                d_bias->data<BatchNormParamType<T>>());
      }
    } else {
      if (d_x) {
        KeBNBackwardData<T, phi::DataLayout::kNHWC>
            <<<grid1, block, 0, stream>>>(d_y->data<T>(),
                                          scale.data<BatchNormParamType<T>>(),
                                          running_var_data,
                                          epsilon,
                                          C,
                                          H * W * D,
                                          num,
                                          d_x->data<T>());
      }
      if (d_scale && d_bias) {
        KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
            <<<grid2, block, 0, stream>>>(
                d_y->data<T>(),
                x.data<T>(),
                running_mean_data,
                running_var_data,
                epsilon,
                N,
                C,
                H * W * D,
                d_scale->data<BatchNormParamType<T>>(),
                d_bias->data<BatchNormParamType<T>>());
      }
    }
  }
}
// Public batch-norm gradient kernel: forwards every argument to
// BatchNormGradRawKernel with in-place mode switched off.
template <typename T, typename Context>
void BatchNormGradKernel(const Context &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor &scale,
                         const DenseTensor &bias,
                         const paddle::optional<DenseTensor> &mean,
                         const paddle::optional<DenseTensor> &variance,
                         const DenseTensor &saved_mean,
                         const DenseTensor &saved_variance,
                         const paddle::optional<DenseTensor> &reserve_space,
                         const DenseTensor &y_grad,
                         float momentum,
                         float epsilon,
                         const std::string &data_layout,
                         bool is_test,
                         bool use_global_stats,
                         bool trainable_statistics,
                         bool fuse_with_relu,
                         DenseTensor *x_grad,
                         DenseTensor *scale_grad,
                         DenseTensor *bias_grad) {
  // The raw kernel takes one extra flag; this entry point never runs in place.
  const bool is_inplace = false;
  BatchNormGradRawKernel<T, Context>(dev_ctx,
                                     x,
                                     scale,
                                     bias,
                                     mean,
                                     variance,
                                     saved_mean,
                                     saved_variance,
                                     reserve_space,
                                     y_grad,
                                     momentum,
                                     epsilon,
                                     data_layout,
                                     is_test,
                                     use_global_stats,
                                     trainable_statistics,
                                     fuse_with_relu,
                                     is_inplace,
                                     x_grad,
                                     scale_grad,
                                     bias_grad);
}
// Second-order batch-norm gradient. Rejects test mode, resolves the layout,
// and delegates the computation to paddle::operators::NormDoubleGradFunctor.
// The running mean/variance are only handed over when use_global_stats is set;
// otherwise null pointers are passed in their place.
// `momentum`, `trainable_statistics` and `fuse_with_relu` are not read here.
template <typename T, typename Context>
void BatchNormDoubleGradKernel(const Context &ctx,
                               const DenseTensor &x,
                               const DenseTensor &scale,
                               const paddle::optional<DenseTensor> &mean,
                               const paddle::optional<DenseTensor> &variance,
                               const DenseTensor &saved_mean,
                               const DenseTensor &saved_variance,
                               const DenseTensor &y_grad,
                               const DenseTensor &x_grad_grad,
                               const DenseTensor &scale_grad_grad,
                               const DenseTensor &bias_grad_grad,
                               float momentum,
                               float epsilon,
                               const std::string &data_layout_str,
                               bool is_test,
                               bool use_global_stats,
                               bool trainable_statistics,
                               bool fuse_with_relu,
                               DenseTensor *x_grad,
                               DenseTensor *scale_grad,
                               DenseTensor *y_grad_grad) {
  PADDLE_ENFORCE_EQ(is_test,
                    false,
                    phi::errors::InvalidArgument(
                        "`is_test = True` CANNOT be used in train program. If "
                        "you want to use global status in pre_train model, "
                        "please set `use_global_stats = True`"));

  const DataLayout layout =
      paddle::framework::StringToDataLayout(data_layout_str);

  // Global statistics are only consulted when explicitly requested.
  const DenseTensor *global_mean =
      use_global_stats ? mean.get_ptr() : nullptr;
  const DenseTensor *global_variance =
      use_global_stats ? variance.get_ptr() : nullptr;

  paddle::operators::NormDoubleGradFunctor<Context, T>(ctx,
                                                       layout,
                                                       &x,
                                                       &scale,
                                                       &y_grad,
                                                       &saved_mean,
                                                       &saved_variance,
                                                       global_mean,
                                                       global_variance,
                                                       epsilon,
                                                       use_global_stats,
                                                       &x_grad_grad,
                                                       &scale_grad_grad,
                                                       &bias_grad_grad,
                                                       x_grad,
                                                       scale_grad,
                                                       y_grad_grad);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
phi::BatchNormGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(batch_norm_grad_raw,
GPU,
ALL_LAYOUT,
phi::BatchNormGradRawKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,