/
reducer.cc
998 lines (870 loc) · 37.2 KB
/
reducer.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/reducer.h"
namespace paddle {
namespace distributed {
// Translates the allocation type of `place` into the corresponding Backend.
// Only GPU and CPU allocation types are listed in the lookup table; every
// other place type is rejected with an InvalidArgument error.
static Backend TransToBackend(platform::Place place) {
  static const std::map<phi::AllocationType, Backend> type_backend = {
      {phi::AllocationType::GPU, Backend::GPU},
      {phi::AllocationType::CPU, Backend::CPU},
  };

  const auto it = type_backend.find(place.GetType());
  PADDLE_ENFORCE_EQ(it != type_backend.end(), true,
                    platform::errors::InvalidArgument(
                        "Place type (%s) is not supported. ", place));
  return it->second;
}
// Splits `tensors` into communication groups and returns, for each group, the
// indices of its members.
//
//  - A tensor flagged in `is_sparse_gradient` always forms a group of its own.
//  - Dense tensors are bucketed per dtype. A bucket is closed as soon as its
//    accumulated byte size reaches the current entry of `group_size_limits`;
//    that dtype then advances to the next limit (the last limit is reused once
//    the list is exhausted).
//  - Tensors that are not dense tensors are skipped.
//  - `tensor_indices`, when non-empty, must be a permutation of
//    [0, tensors.size()); entry i is the index recorded for tensors[i]. When
//    it is empty, tensors[i] is recorded as i and the resulting groups are
//    sorted by their first index.
//
// Raises PreconditionNotMet when the argument sizes are inconsistent, when
// `tensor_indices` is not a permutation, or when a dense tensor has to be
// grouped but `group_size_limits` is empty.
std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
    const std::vector<Tensor> tensors,
    const std::vector<bool> &is_sparse_gradient,
    const std::vector<size_t> &group_size_limits,
    const std::vector<int64_t> &tensor_indices) {
  PADDLE_ENFORCE_EQ(
      tensors.size(), is_sparse_gradient.size(),
      platform::errors::PreconditionNotMet(
          "tensors len must be equal to is_sparse_gradient len, but "
          "[%lu] != [%lu]",
          tensors.size(), is_sparse_gradient.size()));

  // True iff x holds every value of [0, x.size()) exactly once. An empty
  // vector is trivially accepted.
  auto check_perm = [](const std::vector<int64_t> &x) -> bool {
    size_t len = x.size();
    std::vector<size_t> cnt(len, 0);
    for (size_t i = 0; i < len; ++i) {
      if (x[i] >= static_cast<int64_t>(len) || x[i] < 0 || cnt[x[i]]) {
        return false;
      }
      cnt[x[i]]++;
    }
    return true;
  };

  PADDLE_ENFORCE_EQ(true, check_perm(tensor_indices),
                    platform::errors::PreconditionNotMet(
                        "tensor_indices must be a permutation from 0 to %lu",
                        tensor_indices.size()));
  // tensor_indices[i] is read for every i < tensors.size() and its values are
  // used to index is_sparse_gradient, so a length mismatch would read out of
  // bounds.
  PADDLE_ENFORCE_EQ(
      tensor_indices.empty() || tensor_indices.size() == tensors.size(), true,
      platform::errors::PreconditionNotMet(
          "tensor_indices must be empty or have the same len as tensors, but "
          "[%lu] != [%lu]",
          tensor_indices.size(), tensors.size()));

  // the return vector
  std::vector<std::vector<size_t>> res;

  // Key: the var type
  // Value: should use which index in group_size_limits for group size limit
  std::map<experimental::DataType, size_t> group_limit_index;

  // Key: the var type
  // Value: <the var index in input tensors, total numel in this group>
  std::map<experimental::DataType, std::pair<std::vector<size_t>, size_t>>
      next_group;

  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto &var = tensors[i];

    size_t tensor_real_index = i;
    if (!tensor_indices.empty()) {
      tensor_real_index = tensor_indices[i];
    }

    if (is_sparse_gradient[tensor_real_index]) {
      // we keep sparse var a single group
      res.push_back({tensor_real_index});
      continue;
    }

    const auto var_dtype = var.dtype();
    VLOG(3) << "var[" << var.name() << "] 's type is " << var_dtype;
    auto &group_info = next_group[var_dtype];

    int64_t var_size = -1;
    if (var.is_dense_tensor()) {
      var_size =
          std::dynamic_pointer_cast<phi::DenseTensor>(var.impl())->numel();
    } else {
      VLOG(3) << "var " << var.name()
              << " is not tensor or selected_rows, so skip it";
      continue;
    }

    // The limits are only consulted for dense tensors, so the emptiness check
    // is done here rather than up front: inputs that never reach this point
    // keep working without any limit.
    PADDLE_ENFORCE_EQ(
        group_size_limits.empty(), false,
        platform::errors::PreconditionNotMet(
            "group_size_limits must not be empty when dense tensors are "
            "assigned to groups."));

    group_info.first.push_back(tensor_real_index);
    group_info.second += experimental::SizeOf(var_dtype) * var_size;

    // operator[] value-initializes the entry to 0 for the first var of a
    // dtype, i.e. every dtype starts at the first limit.
    auto &cur_limit_index = group_limit_index[var_dtype];
    if (group_info.second >= group_size_limits[cur_limit_index]) {
      // exceed group capacity and create a new group
      res.emplace_back(std::move(group_info.first));
      group_info = std::pair<std::vector<size_t>, size_t>();
      cur_limit_index =
          (std::min)(cur_limit_index + 1, group_size_limits.size() - 1);
    }
  }

  // add the final groups
  for (auto &e : next_group) {
    auto &group_info = e.second;
    if (!group_info.first.empty()) {
      res.emplace_back(std::move(group_info.first));
    }
  }

  for (const auto &group_index : res) {
    PADDLE_ENFORCE_NE(
        group_index.empty(), true,
        platform::errors::PreconditionNotMet(
            "AssignGroupBySize construct empty group, please check."));
  }

  // Without an explicit permutation the groups are ordered by their first
  // member; with one, the construction order is kept.
  if (tensor_indices.empty()) {
    std::sort(res.begin(), res.end(),
              [](const std::vector<size_t> &x, const std::vector<size_t> &y) {
                return x.front() < y.front();
              });
  }
  return res;
}
// Concatenates `dense_tensors_` along axis 0 into the dense tensor held by
// `p_dense_contents`, running ConcatFunctor<DeviceContext, T> on `context`.
template <typename DeviceContext, typename T>
static void ConcatTensorsForAllReduce(
    const DeviceContext &context,
    const std::vector<phi::DenseTensor> &dense_tensors_,
    Tensor *p_dense_contents) {
  operators::math::ConcatFunctor<DeviceContext, T> concat;
  // The impl stays owned by *p_dense_contents, so the raw pointer remains
  // valid after the temporary shared_ptr goes away.
  auto *fused_buffer =
      std::dynamic_pointer_cast<phi::DenseTensor>(p_dense_contents->impl())
          .get();
  concat(context, dense_tensors_, 0, fused_buffer);
}
template <typename DeviceContext, typename T>
static void SplitTensorsForAllReduce(
const DeviceContext &context, Tensor *p_dense_contents,
std::vector<phi::DenseTensor> *p_dense_tensors) {
auto *in =
std::dynamic_pointer_cast<phi::DenseTensor>(p_dense_contents->impl())
.get();
std::vector<phi::DenseTensor *> outs;
std::vector<const phi::DenseTensor *> shape_refer;
outs.reserve(p_dense_tensors->size());
shape_refer.reserve(p_dense_tensors->size());
for (auto &tensor : *p_dense_tensors) {
outs.emplace_back(&tensor);
shape_refer.emplace_back(&tensor);
}
operators::math::SplitFunctor<DeviceContext, T> split_functor_;
split_functor_(context, *in, shape_refer, 0, &outs);
}
// Dispatches the concat to the element type matching `type`. FLOAT16,
// FLOAT32 and FLOAT64 are handled; any other dtype raises Unimplemented.
// context is used to select the stream for concat
template <typename DeviceContext>
static void ConcatTensorsWithType(
    const DeviceContext &context,
    const std::vector<phi::DenseTensor> &dense_tensors_,
    Tensor *p_dense_contents, phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT16: {
      ConcatTensorsForAllReduce<DeviceContext, platform::float16>(
          context, dense_tensors_, p_dense_contents);
      return;
    }
    case phi::DataType::FLOAT32: {
      ConcatTensorsForAllReduce<DeviceContext, float>(context, dense_tensors_,
                                                      p_dense_contents);
      return;
    }
    case phi::DataType::FLOAT64: {
      ConcatTensorsForAllReduce<DeviceContext, double>(context, dense_tensors_,
                                                       p_dense_contents);
      return;
    }
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it concats tensors for "
          "allreduce.",
          type));
  }
}
// Dispatches the split to the element type matching `type`. FLOAT16, FLOAT32
// and FLOAT64 are handled; any other dtype raises Unimplemented.
// context is used to select the stream for split
template <typename DeviceContext>
static void SplitTensorsWithType(const DeviceContext &context,
                                 Tensor *p_dense_contents,
                                 std::vector<phi::DenseTensor> *p_dense_tensors,
                                 phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT16: {
      SplitTensorsForAllReduce<DeviceContext, platform::float16>(
          context, p_dense_contents, p_dense_tensors);
      return;
    }
    case phi::DataType::FLOAT32: {
      SplitTensorsForAllReduce<DeviceContext, float>(context, p_dense_contents,
                                                     p_dense_tensors);
      return;
    }
    case phi::DataType::FLOAT64: {
      SplitTensorsForAllReduce<DeviceContext, double>(context, p_dense_contents,
                                                      p_dense_tensors);
      return;
    }
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it splits tensors for "
          "allreduce.",
          type));
  }
}
// Fuses dense_tensors_ into dense_contents_ using the device context that the
// DeviceContextPool holds for `place`. GPU places need a build with NCCL or
// RCCL; places other than GPU/CPU raise Unimplemented.
void EagerGroup::ConcatTensors(const platform::Place &place) {
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    auto *gpu_ctx = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    ConcatTensorsWithType(*gpu_ctx, dense_tensors_, &dense_contents_, dtype_);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat grad tensors since it's not compiled with NCCL,"
        "Please recompile or reinstall Paddle with NCCL support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    auto *cpu_ctx = static_cast<platform::CPUDeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    ConcatTensorsWithType(*cpu_ctx, dense_tensors_, &dense_contents_, dtype_);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Concat grad tensor not supported on place (%s)", place));
  }
}
// Scatters dense_contents_ back into dense_tensors_ using the device context
// that the DeviceContextPool holds for `place`. GPU places need a build with
// NCCL or RCCL; places other than GPU/CPU raise Unimplemented.
void EagerGroup::SplitTensors(const platform::Place &place) {
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    auto *gpu_ctx = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    SplitTensorsWithType(*gpu_ctx, &dense_contents_, &dense_tensors_, dtype_);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split grad tensor since it's not compiled with NCCL,"
        "Please recompile or reinstall Paddle with NCCL support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    auto *cpu_ctx = static_cast<platform::CPUDeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
    SplitTensorsWithType(*cpu_ctx, &dense_contents_, &dense_tensors_, dtype_);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Split grad tensor not supported on place (%s)", place));
  }
}
// Builds the reducer for `tensors`:
//  1. stores the arguments and reads the rank count from the process group,
//  2. builds the communication groups described by `group_indices`,
//  3. registers, on the GradNodeAccumulation of every tensor, a reduce hook
//     that calls AddDistHook with the tensor's index, and remembers the
//     node -> index mapping in gradnode_index_map_,
//  4. sizes the per-tensor bookkeeping vectors and, when
//     `find_unused_parameters` is set, allocates the INT32 tensor used to
//     exchange the "used" flags between ranks.
//
// Raises if a tensor has no grad node or if its grad node is not a
// GradNodeAccumulation.
EagerReducer::EagerReducer(
    const std::vector<Tensor> tensors,
    const std::vector<std::vector<size_t>> &group_indices,
    const std::vector<bool> &is_sparse_gradient,
    std::shared_ptr<distributed::ProcessGroup> process_group,
    const std::vector<size_t> &group_size_limits, bool find_unused_parameters)
    : tensors_(tensors),
      group_indices_(group_indices),
      is_sparse_gradient_(is_sparse_gradient),
      process_group_(process_group),
      group_size_limits_(group_size_limits),
      find_unused_vars_each_step_(find_unused_parameters) {
  VLOG(3) << "Start construct the Reducer ...";

  nranks_ = process_group_->GetSize();

  // initialize groups
  InitializeGroups(group_indices);

  for (size_t global_var_index = 0; global_var_index < tensors_.size();
       ++global_var_index) {
    auto tensor = tensors_[global_var_index];
    // The hook is stored on the grad node and may run long after this
    // constructor returns, hence the explicit by-value capture of the index.
    auto reduce_hook = [this, global_var_index]() -> void {
      this->AddDistHook(global_var_index);
    };

    const auto grad_node = GetGradNodeFromTensor(&tensor);

    PADDLE_ENFORCE(
        grad_node.get() != nullptr,
        paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                        "Leaf tensor should have had grad_node "
                                        "with type: GradNodeAccumulation"));
    const auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // dynamic_pointer_cast yields null when the node has another type; fail
    // with a diagnostic instead of dereferencing a null pointer below.
    PADDLE_ENFORCE(
        accumulation_grad_node.get() != nullptr,
        paddle::platform::errors::PreconditionNotMet(
            "The grad_node of tensor[%d][%s] is not a GradNodeAccumulation, "
            "so the reduce hook can not be registered on it.",
            global_var_index, tensor.name()));
    accumulation_grad_node->RegisterReduceHook(
        std::make_shared<egr::CppTensorVoidHook>(reduce_hook));

    gradnode_index_map_[grad_node.get()] = global_var_index;
  }

  vars_marked_ready_.resize(tensors_.size(), false);
  local_used_vars_.resize(tensors_.size(), 0);

  if (find_unused_vars_each_step_) {
    global_used_vars_ = paddle::experimental::empty(
        IntArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32,
        inner_place_);
  }
}
// Returns the grad node stored in the autograd meta of `tensor`. The meta
// pointer is used without a null check, so the tensor is expected to carry
// autograd meta.
std::shared_ptr<egr::GradNodeBase> EagerReducer::GetGradNodeFromTensor(
    Tensor *tensor) {
  auto *meta = static_cast<egr::AutogradMeta *>(tensor->get_autograd_meta());
  return meta->GetMutableGradNode();
}
void EagerReducer::InitializeGroups(
const std::vector<std::vector<size_t>> &group_indices) {
VLOG(3) << "Start initialize groups ..";
// clear the group
groups_.clear();
groups_.reserve(group_indices.size());
variable_locators_.clear();
variable_locators_.resize(tensors_.size());
auto group_nums = group_indices.size();
for (size_t group_index = 0; group_index < group_nums; ++group_index) {
const auto &tensor_indices_ = group_indices[group_index];
PADDLE_ENFORCE_GT(
tensor_indices_.size(), 0,
platform::errors::PreconditionNotMet(
"The number of group[%d]'s elements is 0.", group_index));
EagerGroup group;
// It's just for check the sparse or dense
auto first_var = tensors_[tensor_indices_.front()];
if (tensor_indices_.size() == 1 &&
is_sparse_gradient_[tensor_indices_.front()]) {
// process the sparse gradient. one sparse, one group
group.dtype_ = first_var.dtype();
group.is_sparse_ = true;
} else {
// process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group);
group.dense_contents_ = paddle::experimental::empty(
IntArray({group.all_length_}), group.dtype_, inner_place_);
}
// map tensors to this group by VariableLocator
size_t inside_group_index = 0;
for (const auto var_index : tensor_indices_) {
TensorLocator tensor_locator;
tensor_locator.group_index = group_index;
tensor_locator.inside_group_index = inside_group_index++;
variable_locators_[var_index] = tensor_locator;
}
group.tensor_indices_ = std::move(tensor_indices_);
groups_.emplace_back(std::move(group));
VLOG(3) << "The Group[" << group_index << "]:" << groups_.back();
}
}
// Fills the dense-group bookkeeping of `p_group` for the tensors listed in
// `tensor_indices_`: per-tensor element counts, original shapes, one empty
// DenseTensor slot per member, the total element count and the group dtype.
// The first member defines the group dtype and also sets inner_place_; all
// later members must have the same dtype. Every member must be non-sparse,
// initialized and non-empty, otherwise PreconditionNotMet is raised.
void EagerReducer::InitializeDenseGroups(
    const std::vector<size_t> &tensor_indices_, EagerGroup *p_group) {
  VLOG(3) << "InitializeDenseGroups.";

  int64_t all_length = 0;
  const size_t member_count = tensor_indices_.size();
  for (size_t index = 0; index != member_count; ++index) {
    const auto tensor_index = tensor_indices_[index];
    auto &tensor = tensors_[tensor_index];
    auto &tensor_name = tensor.name();

    PADDLE_ENFORCE_EQ(is_sparse_gradient_[tensor_index], false,
                      platform::errors::PreconditionNotMet(
                          "Tensor %s's GRAD must be Tensor, but received "
                          "GRAD is SelectedRows",
                          tensor_name));
    PADDLE_ENFORCE_EQ(tensor.initialized(), true,
                      platform::errors::PreconditionNotMet(
                          "Tensor %s is not initialized.", tensor_name));

    const auto size = tensor.numel();
    PADDLE_ENFORCE_GT(
        size, 0, platform::errors::PreconditionNotMet(
                     "The number of tensor %s's elements is 0.", tensor_name));

    all_length += size;
    p_group->length_.push_back(size);

    // for concat operator
    p_group->origin_shapes_.push_back(IntArray(tensor.shape()));
    p_group->dense_tensors_.push_back(phi::DenseTensor());

    const auto &dtype = tensor.dtype();
    const auto &inner_place = tensor.impl()->place();
    if (index == 0) {
      // The leading member fixes the dtype of the group and the place.
      p_group->dtype_ = dtype;
      inner_place_ = inner_place;
    } else {
      PADDLE_ENFORCE_EQ(dtype, p_group->dtype_,
                        platform::errors::PreconditionNotMet(
                            "Tensor %s has unexpected dtype.", tensor_name));
    }
  }

  p_group->all_length_ = all_length;
}
// Walks the backward graph breadth-first, starting from the grad nodes of
// `outputs`, and appends to unused_vars_ the index of every registered tensor
// whose grad node (a key of gradnode_index_map_) was not reached.
//
// Outputs without autograd meta, without a grad node, or with stop-gradient
// set are ignored as starting points.
void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
  std::queue<egr::GradNodeBase *> queue;
  std::set<egr::GradNodeBase *> visited;

  for (const auto &output : outputs) {
    auto *auto_grad_meta =
        static_cast<egr::AutogradMeta *>(output.get_autograd_meta());
    if (!auto_grad_meta) continue;

    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();
    if (!shared_grad_node || auto_grad_meta->StopGradient()) {
      continue;
    }

    // A starting node counts as reached as well. Otherwise a registered
    // tensor whose own grad node is a starting point would be reported as
    // unused, and a node shared by several outputs would be queued repeatedly.
    egr::GradNodeBase *grad_node = shared_grad_node.get();
    if (visited.insert(grad_node).second) {
      queue.emplace(grad_node);
    }
  }

  while (!queue.empty()) {
    egr::GradNodeBase *node = queue.front();
    queue.pop();

    const auto &metas = node->OutputMeta();
    for (const auto &slot_metas : metas) {
      for (const auto &slot_meta : slot_metas) {
        const egr::Edge &edge = slot_meta.GetEdge();
        auto next_node_shared = edge.GetMutableGradNode();
        if (!next_node_shared) {
          continue;
        }
        auto *next_node = next_node_shared.get();
        // Enqueue each node only the first time it is seen.
        if (visited.insert(next_node).second) {
          queue.emplace(next_node);
        }
      }
    }
  }

  for (const auto &it : gradnode_index_map_) {
    if (visited.count(it.first) == 0) {
      unused_vars_.push_back(it.second);
      VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
              << "Tensor " << tensors_[it.second].name() << " at index "
              << it.second << " is marked as unused.";
    }
  }
}
void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
VLOG(3) << "after forward, then reset count for backward.";
grad_need_hooks_ = true;
next_group_ = 0;
std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
group.pending_ = group.tensor_indices_.size();
group.sparse_contents_ = Tensor();
});
// reinitialize vars_marked_ready_ for next iteration
vars_marked_ready_.clear();
vars_marked_ready_.resize(tensors_.size(), false);
PADDLE_ENFORCE_EQ(
groups_need_finalize_, false,
platform::errors::PreconditionNotMet(
"A serious error has occurred here. Please "
"set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have "
"set, There may be several reasons for this error: "
"1) Please note that all forward outputs derived from the module "
"parameters must participate in the calculation of losses and "
"subsequent gradient calculations. If not, the wrapper will hang, "
"waiting for autograd to generate gradients for these parameters. "
"you can use detach or stop_gradient to make the unused parameters "
"detached from the autograd graph. "
"2) Used multiple forwards and one backward. You may be able to wrap "
"multiple forwards in a model."));
// The first var to trigger the unused parameter
has_marked_unused_vars_ = false;
if (find_unused_vars_once_ || find_unused_vars_each_step_) {
unused_vars_.clear();
TraverseBackwardGraph(outputs);
// only check once in first step
find_unused_vars_once_ = false;
}
if (find_unused_vars_each_step_ && unused_vars_.empty()) {
LOG_FIRST_N(WARNING, 1)
<< "All parameters are involved in the backward pass. "
"It is recommended to set find_unused_parameters to False "
"to improve performance. However, if unused parameters "
"appear in subsequent iterative training, then an error "
"will occur. Please make it clear that in the subsequent "
"training, there will be no parameters that are not used "
"in the backward pass, and then set find_unused_parameters";
}
if (unused_vars_.size() == tensors_.size()) {
LOG_FIRST_N(WARNING, 1)
<< "There is no parameter in the device involved "
"in the backward calculation. If there are "
"parameters on other devices involved in the "
"backward, then a serious error will occur here.";
}
}
// Reduce hook body for the tensor at `var_index`; the constructor registers
// it on the tensor's gradient accumulation node.
//
// Does nothing while grad_need_hooks_ is false. Otherwise it records the
// tensor as locally used, marks all entries of unused_vars_ as ready the
// first time any hook fires in the current step, and finally marks
// `var_index` itself as ready. Raises OutOfRange for an invalid index.
void EagerReducer::AddDistHook(size_t var_index) {
  PADDLE_ENFORCE_LT(var_index, variable_locators_.size(),
                    platform::errors::OutOfRange(
                        "Out of bounds variable index. it must be less "
                        "than %d, but it is %d",
                        variable_locators_.size(), var_index));

  // gradient synchronization is not required when grad_need_hooks_ is false.
  if (!grad_need_hooks_) {
    return;
  }

  VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
          << "@Grad] arrived and triggered disthook";

  local_used_vars_[var_index] = 1;

  // Unused vars never trigger a hook themselves, so they are flushed by the
  // first hook that does fire.
  if (!has_marked_unused_vars_) {
    has_marked_unused_vars_ = true;
    for (const auto unused_index : unused_vars_) {
      MarkVarReady(unused_index, false);
    }
  }
  MarkVarReady(var_index, true);
}
// Marks the tensor at `var_index` as ready for reduction in its group.
//
// `is_used_var` is true when the call comes from the tensor's own hook and
// false when the tensor was found to be unused in this step.
//
//  - Dense group, used var: the group's slot shares the storage of the
//    tensor's gradient, flattened to 1-D.
//  - Dense group, unused var: the slot is allocated if necessary; it then
//    either shares the storage of an existing gradient or is zero-filled.
//  - Sparse group: the gradient must exist and be a SelectedRows; its impl
//    becomes the group's sparse_contents_.
//
// When the last pending member of a group arrives the group is handed to
// MarkGroupReady, and once every group has been scheduled the step is
// finalized. Marking the same var twice in one step raises
// PreconditionNotMet.
void EagerReducer::MarkVarReady(const size_t var_index,
                                const bool is_used_var) {
  VLOG(3) << "Tensor[" << var_index << "][" << tensors_[var_index].name()
          << "] is marked ready.";
  // error happened, if the var is ready before.
  if (vars_marked_ready_[var_index]) {
    auto error_info = string::Sprintf(
        "Error happened, when parameter[%d][%s] has been ready before. "
        "Please set find_unused_parameters=True to traverse backward graph "
        "in each step to prepare reduce in advance. If you have set, "
        "there may be several reasons for this error: "
        "1) In multiple reentrant backward phase, some parameters are reused."
        "2) Using model parameters outside of forward function. Please "
        "make sure that model parameters are not shared in concurrent "
        "forward-backward passes.",
        var_index, tensors_[var_index].name());

    // Exactly one of the two checks below fails: the first reports the short
    // message when no unused vars were marked in this step, the second
    // reports the extended message otherwise.
    PADDLE_ENFORCE_EQ(has_marked_unused_vars_, false,
                      platform::errors::PreconditionNotMet(error_info));

    error_info +=
        "3) Unused parameters retrieval is incorrect. "
        "The return value of forward will be used to retrieve"
        " the unused parameters of the entire model. These "
        "gradients of unused parameters will not be synchronized "
        "between multiple cards. However, if the unused "
        "parameters participate in the backward calculation "
        "again at a later time (e.g. after the forward function, "
        "the loss calculation uses the unused "
        "paramters of the forward and trigger backward), "
        "its gradient will be wrong.";

    PADDLE_ENFORCE_EQ(has_marked_unused_vars_, true,
                      platform::errors::PreconditionNotMet(error_info));
  } else {
    vars_marked_ready_[var_index] = true;
  }

  groups_need_finalize_ = true;

  const auto &var_locator = variable_locators_[var_index];
  const auto group_index = var_locator.group_index;
  const auto inside_group_index = var_locator.inside_group_index;

  auto &group = groups_[group_index];

  if (!group.is_sparse_) {
    // dense_tensors_ and length_ are only populated for dense groups
    // (InitializeDenseGroups), so they must not be indexed for sparse ones.
    auto &group_tensor = group.dense_tensors_[inside_group_index];
    const auto length = group.length_[inside_group_index];

    if (is_used_var) {
      auto *autograd_meta = tensors_[var_index].get_autograd_meta();
      auto &grad_tensor =
          static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
      group_tensor
          .ShareDataWith(*(
              std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
          .Resize({grad_tensor.numel()});
    } else {
      // TODO(shenliang03): maybe save the memory by avoiding tensor
      // construction
      if (!group_tensor.initialized()) {
        group_tensor.Resize({static_cast<int64_t>(length)});
        group_tensor.mutable_data(inner_place_, group.dtype_);
      }
      if (HasGrad(var_index)) {
        VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
        auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
        group_tensor
            .ShareDataWith(*(std::dynamic_pointer_cast<phi::DenseTensor>(
                grad_tensor->impl())))
            .Resize({length});
      } else {
        VLOG(3) << "Tensor[" << tensors_[var_index].name()
                << "] doesn't have grad";
        auto *dev_ctx =
            platform::DeviceContextPool::Instance().Get(inner_place_);
        group_tensor.Resize({static_cast<int64_t>(length)});
        phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
      }
    }
  } else {
    auto *autograd_meta = tensors_[var_index].get_autograd_meta();
    auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();

    // process sparse group
    PADDLE_ENFORCE_EQ(
        HasGrad(var_index), true,
        platform::errors::PreconditionNotMet(
            "The sparse parameter[%d][%s] should have gradient. "
            "Currently, DataParallel does not support sparse "
            "parameters without generating gradients during training. "
            "For example, if is_sparese=True is used in Embedding, "
            "the current step of this parameter cannot generate gradient "
            "because of stop_gradient/detatch, where error will occur.",
            var_index, tensors_[var_index].name()));

    // need to check tensor type
    PADDLE_ENFORCE_EQ(
        grad_tensor.is_selected_rows(), true,
        platform::errors::PreconditionNotMet(
            "The sparse parameter[%d][%s] must have a selectedrows gradient. "
            "Before forward pass, the parameter type is inferred to be "
            "SelectedRows, but after backward pass, its actual type becomes "
            "LodTensor. It is currently not supported by DataParallel. "
            "For example, if sparse embedding is used, and the weight of "
            "embedding is shared with subsequent dense parameters, then "
            "the parameter gradient of the embedding will be converted "
            "to dense parameters.",
            var_index, tensors_[var_index].name()));

    group.sparse_contents_.set_impl(grad_tensor.impl());
  }

  if (--group.pending_ == 0) {
    // can start allreduce
    MarkGroupReady(group_index);
  }

  // next_group_ reaches groups_.size() once every group has been scheduled.
  if (next_group_ == groups_.size()) {
    FinalizeBackward();
  }
}
// Called when every member of group `group_index` is ready. Groups are
// reduced strictly in index order: a group that completes ahead of
// next_group_ only gets logged, while completing next_group_ itself launches
// that group and every directly following group whose pending count is
// already zero. Raises PreconditionNotMet when `group_index` lies before
// next_group_.
void EagerReducer::MarkGroupReady(size_t group_index) {
  VLOG(3) << "Group[" << group_index << "] is ready";

  PADDLE_ENFORCE_GE(
      group_index, next_group_,
      platform::errors::PreconditionNotMet(
          "The index of the incoming group must be greater "
          "than or equal to the previously synchronized group index, "
          "expect it to greater than or equal to %d, but got %d.",
          next_group_, group_index));

  if (group_index > next_group_) {
    VLOG(3) << "It will adjust the order of group in next batch automatically";
    return;
  }

  while (next_group_ < groups_.size() && groups_[next_group_].pending_ == 0) {
    UNUSED auto &group = groups_[next_group_];
    if (group.is_sparse_) {
      AllReduceSparse(&group, next_group_);
    } else {
      FusedAllReduceSchedule(&group, next_group_);
    }
    ++next_group_;
  }
}
// Reports whether the tensor at `var_index` has a gradient tensor that is
// both present and initialized.
bool EagerReducer::HasGrad(size_t var_index) {
  auto grad = egr::EagerUtils::mutable_grad(tensors_[var_index]);
  return grad && grad->initialized();
}
// Handles tensors that were unused on this rank but used on another one.
//
// The local "used" flags are copied into global_used_vars_, summed across the
// process group, and read back into local_used_vars_. For every locally
// unused tensor whose summed flag is non-zero and that belongs to a dense
// group, the group's slot for that tensor is copied into the tensor's
// gradient. Sparse groups are skipped.
void EagerReducer::ProcessUnusedDenseVars() {
  // The calculation stream must be used here to
  // avoid conflicts with communication.
  VLOG(3) << "Local used vars : "
          << string::join_strings(local_used_vars_, ',');

  const auto *dev_ctx =
      platform::DeviceContextPool::Instance().Get(inner_place_);
  auto *global_used_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(global_used_vars_.impl())
          .get();
  framework::TensorFromVector<int32_t>(local_used_vars_, *dev_ctx,
                                       global_used_tensor);

  distributed::AllreduceOptions opts;
  opts.reduce_op = ReduceOp::SUM;
  std::vector<Tensor> reduce_tensors = {global_used_vars_};
  std::vector<phi::DenseTensor> in_out;
  for (auto &t : reduce_tensors) {
    in_out.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
  }
  process_group_->AllReduce(in_out, in_out, opts)->Synchronize();

  framework::TensorToVector<int>(*global_used_tensor, *dev_ctx,
                                 &local_used_vars_);
  dev_ctx->Wait();

  // sync compute stream to get global used var message,
  // but maybe affect speed performance
  VLOG(3) << "Global used vars : "
          << string::join_strings(local_used_vars_, ',');

  for (const auto var_index : unused_vars_) {
    // After the allreduce local_used_vars_ holds the summed flags.
    const bool global_unused = (local_used_vars_[var_index] == 0);

    // global used but local unused, set grad
    VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
            << "Var [" << var_index << "] [" << tensors_[var_index].name()
            << "] global_unused: " << global_unused
            << "  has grad: " << HasGrad(var_index);

    if (!global_unused) {
      VLOG(3) << "Set Tensor[" << var_index << "]'s Grad for [Rank "
              << process_group_->GetRank() << "]";
      const auto &var_locator = variable_locators_[var_index];
      const auto group_index = var_locator.group_index;
      const auto &group = groups_[group_index];

      // sparse no need to check and no support find_unused_parameters
      // This check has to precede the dense_tensors_ lookup: sparse groups
      // never populate dense_tensors_.
      if (group.is_sparse_) {
        continue;
      }

      const auto inside_group_index = var_locator.inside_group_index;
      auto &src_tensor = group.dense_tensors_[inside_group_index];

      Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));

      auto dest_var_base = tensors_[var_index];
      auto grad_tensor = egr::EagerUtils::mutable_grad(dest_var_base);
      grad_tensor->copy_(grad_value, inner_place_, true);
      grad_tensor->reshape(dest_var_base.shape());
    }
  }
}
void EagerReducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
for (auto &group : groups_) {
if (!group.is_sparse_) {
group.task->Synchronize();
}
}
for (auto &group : groups_) {
if (!group.is_sparse_) {
group.SplitTensors(inner_place_);
}
}
if (find_unused_vars_each_step_) {
ProcessUnusedDenseVars();
local_used_vars_.clear();
local_used_vars_.resize(tensors_.size(), 0);
VLOG(3) << "ProcessUnusedDenseVars is finished.";
}
VLOG(3) << "In the batch, Reducer is finished.";
}
// Launches the fused allreduce of a dense group: the member tensors are
// concatenated into dense_contents_, the buffer is scaled by 1 / nranks_, and
// a SUM allreduce is started whose task handle is stored in group->task.
// Waiting for the task and splitting the buffer happen in FinalizeBackward().
void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
                                          const int curr_group_index) {
  // The overall timeline: concat > div_nranks > allreduce > split
  distributed::AllreduceOptions opts;
  opts.reduce_op = ReduceOp::SUM;

  VLOG(3) << "group [" << curr_group_index << "] start fused_allreduce.";

  // concat tensors
  group->ConcatTensors(inner_place_);

  // div nranks
  paddle::experimental::scale_(group->dense_contents_, 1.0 / nranks_, 0.0,
                               false);

  // all_reduce, using the fused buffer as both input and output
  std::vector<phi::DenseTensor> in_out;
  in_out.push_back(*std::dynamic_pointer_cast<phi::DenseTensor>(
      group->dense_contents_.impl()));
  group->task = process_group_->AllReduce(in_out, in_out, opts);

  // split in FinalizeBackward()
}
// Averages a sparse gradient group (phi::SelectedRows) across all ranks.
//
// The sparse contents are first scaled by 1 / nranks_. A per-rank row-count
// vector is then all-reduced (SUM) so that every rank learns how many rows
// each rank holds. When all ranks hold the same number of rows, the row ids
// and the value tensor are all-gathered and written back into
// group->sparse_contents_; any other row distribution is rejected.
//
// group: group whose sparse_contents_ is reduced in place.
// curr_group_index: index of the group; only used in log/error output.
//
// Throws (via PADDLE_THROW):
//   PermissionDenied - GPU place but Paddle was built without NCCL/RCCL.
//   Unimplemented    - place is neither GPU nor CPU, or the ranks hold
//                      different numbers of rows.
//   InvalidArgument  - sparse_contents_ is not a SelectedRows, or its value
//                      tensor has no dimensions.
void EagerReducer::AllReduceSparse(EagerGroup *group,
                                   const int curr_group_index) {
  // div nranks
  Tensor sparse_tensor(group->sparse_contents_);
  paddle::experimental::scale_(sparse_tensor, 1.0 / nranks_, 0.0, false);

  VLOG(3) << "sparse_group [" << curr_group_index << "] start allreduce.";

  // The pool already hands back the context as the base type that every call
  // below expects, so no per-place downcast is needed; the branches only
  // validate that the place is usable.
  auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
  if (platform::is_gpu_place(inner_place_)) {
#if !defined(PADDLE_WITH_NCCL) && !defined(PADDLE_WITH_RCCL)
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't allreduce sparse grad tensors since it's not compiled "
        "with NCCL. Please recompile or reinstall Paddle with NCCL support."));
#endif
  } else if (!platform::is_cpu_place(inner_place_)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Sparse allreduce is not supported on place (%s)", inner_place_));
  }

  auto src = std::dynamic_pointer_cast<phi::SelectedRows>(
      group->sparse_contents_.impl());
  if (src == nullptr) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The sparse contents of group [%d] must be a SelectedRows.",
        curr_group_index));
  }
  const auto &src_rows = src->rows();
  const auto rank = process_group_->GetRank();
  const auto world_size = process_group_->GetSize();

  // Exchange row counts: each rank fills only its own slot, so a SUM
  // all-reduce leaves every rank with the full per-rank count vector.
  framework::Vector<int64_t> rows_num_vector(world_size);
  rows_num_vector[rank] = static_cast<int64_t>(src_rows.size());

  Tensor rows_num_tensor = paddle::experimental::empty(
      IntArray({static_cast<int64_t>(world_size)}), DataType::INT64,
      inner_place_);
  auto *rows_num_dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(rows_num_tensor.impl())
          .get();
  framework::TensorFromVector<int64_t>(rows_num_vector, *dev_ctx,
                                       rows_num_dense_tensor);

  distributed::AllreduceOptions opts;
  opts.reduce_op = ReduceOp::SUM;
  std::vector<phi::DenseTensor> in_out;
  in_out.push_back(*rows_num_dense_tensor);
  process_group_->AllReduce(in_out, in_out, opts)->Synchronize();

  framework::TensorToVector<int64_t>(*rows_num_dense_tensor, *dev_ctx,
                                     &rows_num_vector);
  // Wait for the copy back to the host before reading rows_num_vector.
  dev_ctx->Wait();

  const auto *cpu_rows_num_ptr = rows_num_vector.data();
  const auto rows_num =
      std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + world_size,
                      static_cast<int64_t>(0));
  VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
          << ", total rows number: " << rows_num
          << ", height: " << src->height();

  // Only the uniform case (same row count on every rank) is implemented.
  const bool same_rows_on_all_ranks = std::all_of(
      cpu_rows_num_ptr, cpu_rows_num_ptr + world_size,
      [&](int64_t row) { return row == cpu_rows_num_ptr[0]; });
  if (!same_rows_on_all_ranks) {
    PADDLE_THROW(
        platform::errors::Unimplemented("This case is not supported."));
  }

  // During sparse communication, the number of each card is same.
  // allgather is used to speed up the allreduce by replacing broadcast.
  VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";

  // ---- gather the row ids ----
  Tensor dst_rows_tensor =
      paddle::experimental::empty(IntArray({static_cast<int64_t>(rows_num)}),
                                  DataType::INT64, inner_place_);
  Tensor src_rows_tensor = paddle::experimental::empty(
      IntArray({static_cast<int64_t>(src_rows.size())}), DataType::INT64,
      inner_place_);
  auto *src_rows_dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(src_rows_tensor.impl())
          .get();
  auto *dst_rows_dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(dst_rows_tensor.impl())
          .get();
  framework::TensorFromVector<int64_t>(src_rows, *dev_ctx,
                                       src_rows_dense_tensor);

  std::vector<phi::DenseTensor> in;
  std::vector<phi::DenseTensor> out;
  in.push_back(*src_rows_dense_tensor);
  out.push_back(*dst_rows_dense_tensor);
  process_group_->AllGather(in, out)->Synchronize();

  framework::Vector<int64_t> dst_rows_vector(rows_num, 0);
  framework::TensorToVector<int64_t>(*dst_rows_dense_tensor, *dev_ctx,
                                     &dst_rows_vector);
  dev_ctx->Wait();

  // ---- gather the values ----
  Tensor src_value_tensor(std::make_shared<phi::DenseTensor>(src->value()));
  std::vector<int64_t> dst_shape = src_value_tensor.shape();
  if (dst_shape.empty()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The value tensor of the sparse gradient in group [%d] must have at "
        "least one dimension.",
        curr_group_index));
  }
  // Rows are laid out along the leading dimension. Indexing it directly is
  // identical to the former `size() - 2` for the usual rank-2 value and
  // avoids an out-of-range write for other ranks.
  dst_shape.front() = rows_num;

  Tensor dst_value_tensor = paddle::experimental::full(
      IntArray(dst_shape), 0, src_value_tensor.dtype(), inner_place_);
  const auto src_value_dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(src_value_tensor.impl());
  const auto dst_value_dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(dst_value_tensor.impl());

  std::vector<phi::DenseTensor> src_dense;
  std::vector<phi::DenseTensor> dst_dense;
  src_dense.push_back(*src_value_dense);
  dst_dense.push_back(*dst_value_dense);
  process_group_->AllGather(src_dense, dst_dense)->Synchronize();

  // Publish the gathered rows and values back into the group's SelectedRows.
  src->set_rows(dst_rows_vector);
  *(src->mutable_value()) = *dst_value_dense;
}
// Writes a short, human-readable summary of an EagerGroup to `out`:
// a header line with all_length_ (labelled "numel") and the number of tensor
// indices, followed by a bracketed, space-separated list of at most the first
// 100 indices. " ..." is appended when the list was truncated.
// Returns `out` to allow chaining.
std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
  constexpr int kMaxPrinted = 100;
  const auto &indices = group.tensor_indices_;

  out << "numel: " << group.all_length_ << " ;var number: " << indices.size()
      << "\n";

  out << "[";
  auto it = indices.begin();
  const auto last = indices.end();
  int printed = 0;
  while (it != last && printed < kMaxPrinted) {
    if (printed != 0) {
      out << ' ';
    }
    out << *it;
    ++it;
    ++printed;
  }
  if (it != last) {
    out << " ...";
  }
  out << "]\n";
  return out;
}
} // namespace distributed
} // namespace paddle