[Dygraph] Support grad division to nranks before reduce in sharding stage2 (#47764)
haohongxiang committed Nov 10, 2022
1 parent 7964119 commit 3addd56
Showing 3 changed files with 21 additions and 19 deletions.
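
Taken together, the diffs below move the division by the sharding group's size out of the post-accumulation grad-storage scaling and into a per-parameter gradient hook that fires before the reduce; _grad_scale keeps only the extra division by the data-parallel group size ahead of the dp_group allreduce. A small arithmetic sketch of why the combined factor is unchanged, assuming _world_size_scaling equals 1.0 / sharding_nranks (as the replaced branch suggests) and using hypothetical group sizes:

    sharding_nranks = 2  # hypothetical sharding (stage-2) group size
    dp_nranks = 4        # hypothetical data-parallel group size

    # Old path: a single factor applied to the accumulated grad storages.
    old_factor = 1.0 / (sharding_nranks * dp_nranks)           # 0.125

    # New path: the grad hook divides by sharding_nranks as each gradient is
    # produced, and _grad_scale divides by dp_nranks just before the dp_group
    # allreduce.
    new_factor = (1.0 / sharding_nranks) * (1.0 / dp_nranks)   # 0.125

    assert abs(old_factor - new_factor) < 1e-12
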
@@ -498,12 +498,7 @@ def _offload_clear_grad(self):
        with device_guard(self._rank, self.offload_device):
            self.offload_grads.buffer.zero_()

-    def step(self):
-        """
-        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
-        """
-        # This method won't be called directly by opt.step()!
-        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+    def _step(self):
        if self._broadcast_overlap:
            # Clear the pre forward hook in the optimizer step.
            for hook_remove in self._forward_pre_hook_remove_helper:
@@ -536,6 +531,14 @@ def step(self):
        # Synchronize all the updated shards in between the ranks
        self._broadcast_params()

+    def step(self):
+        """
+        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
+        """
+        # This method won't be called directly by opt.step()!
+        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
+        self._step()
+
    def minimize(self):
        raise RuntimeError(
            "optimizer.minimize() not support now, please use optimizer.step()"
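
The step()/_step() split above keeps step() as a stable entry point that GroupShardedStage2 can rebind. The body of _redefine_opt_step is not part of this commit, so the following is only a hypothetical sketch of the wrapping pattern the comments describe, with invented names:

    # Hypothetical sketch only: the real _redefine_opt_step lives in
    # GroupShardedStage2 and is not shown in this diff.
    import types

    def _redefine_opt_step_sketch(stage2_model, sharding_optimizer):
        inner_step = sharding_optimizer.step

        def step(opt):
            # e.g. pre-scale grads for the dp_group, then run the real update
            stage2_model._grad_scale()
            inner_step()

        sharding_optimizer.step = types.MethodType(step, sharding_optimizer)
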
@@ -225,13 +225,13 @@ def _clear_gradients(self):

    def _grad_scale(self):
        """
-        Before the gradient accumulation, scale the gradient.
+        Before the optimization, scale the gradients before allreduce of dp_group.
        """

        if self._dp_group is None or self._dp_group.nranks <= 1:
-            scale_factor = self._world_size_scaling
+            return
        else:
-            scale_factor = 1.0 / (self._group.nranks * self._dp_group.nranks)
+            scale_factor = 1.0 / (self._dp_group.nranks)

        # Scale grad storages
        for dtype in self._grad_storages.keys():
@@ -366,6 +366,13 @@ def _set_reduce_overlap(self, reduce_overlap):
        ), "Only support comm overlap strategy for single optimizer"
        self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)

+    def _get_scaled_grad_fn(self):
+        @paddle.autograd.no_grad()
+        def scale(grad):
+            grad.scale_(self._world_size_scaling)
+
+        return scale
+
    def _get_reduce_fn(self, index, param, dst_rank):
        """
        There are two ways to reduce gradient.
@@ -510,6 +517,8 @@ def _setup_backward_hooks(self):
            return

        for index, param in enumerate(self._trainable_params):
+            param._register_grad_hook(self._get_scaled_grad_fn())
+
            dst_rank = self._trainable_param2rank[param.name]

            reduce_function = self._get_reduce_fn(index, param, dst_rank)
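
For a standalone view of what the new hook does, here is a minimal sketch that uses the public paddle Tensor.register_hook API instead of the internal param._register_grad_hook call above; nranks is a hypothetical group size, and the division happens on the gradient itself before any later reduce would see it:

    import paddle

    nranks = 4  # hypothetical group size

    x = paddle.ones([2, 2])
    x.stop_gradient = False

    # Divide the gradient by nranks as soon as it is produced, so a later
    # reduce over the group only needs to sum.
    x.register_hook(lambda grad: grad / nranks)

    (x * 3).sum().backward()
    print(x.grad)  # each entry is 3 / nranks = 0.75
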
@@ -153,16 +153,6 @@ def test_sharding_api():
        list(range(paddle.distributed.get_world_size()))
    )

-    stage2_dp_params = train_mlp(
-        mlp1,
-        shard_level="os_g",
-        use_multi_precision=True,
-        output_dir=output_dir,
-        amp_level='O2',
-        sync_buffers=True,
-        dp_group=dp_group,
-    )
-
    # fp16
    stage2_params = train_mlp(
        mlp1,
