bagua.py
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import (
    _LightningModuleWrapperBase,
    _LightningPrecisionModuleWrapperBase,
    unwrap_lightning_module,
)
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
from pytorch_lightning.utilities.seed import reset_seed

if _BAGUA_AVAILABLE:
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms import Algorithm
    from bagua.torch_api.algorithms.q_adam import QAdamOptimizer
    from bagua.torch_api.communication import allreduce_inplace, barrier, broadcast_object, is_initialized
    from bagua.torch_api.communication import ReduceOp as BaguaReduceOp
    from bagua.torch_api.data_parallel.distributed import DistributedDataParallel_V1_9_0 as BaguaDistributedDataParallel
else:
    BaguaReduceOp = None
    BaguaDistributedDataParallel = None

log = logging.getLogger(__name__)


class LightningBaguaModule(_LightningModuleWrapperBase):
    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
        super().__init__(pl_module)
        # Bagua uses `bagua_module_name` to distinguish different modules
        self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"


if _BAGUA_AVAILABLE:
    # Convert a reduce op to its equivalent `bagua.torch_api.ReduceOp`
    _bagua_reduce_ops = {
        ReduceOp.SUM: BaguaReduceOp.SUM,
        ReduceOp.PRODUCT: BaguaReduceOp.PRODUCT,
        ReduceOp.MIN: BaguaReduceOp.MIN,
        ReduceOp.MAX: BaguaReduceOp.MAX,
        ReduceOp.BAND: BaguaReduceOp.BAND,
        ReduceOp.BOR: BaguaReduceOp.BOR,
        ReduceOp.BXOR: BaguaReduceOp.BXOR,
        "avg": BaguaReduceOp.AVG,
        "mean": BaguaReduceOp.AVG,
        "sum": BaguaReduceOp.SUM,
    }
else:
    _bagua_reduce_ops = {}


class BaguaStrategy(DDPStrategy):
    strategy_name = "bagua"

    def __init__(
        self,
        algorithm: str = "gradient_allreduce",
        flatten: bool = True,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        **bagua_kwargs: Union[Any, Dict[str, Any]],
    ):
"""Strategy for training using the `Bagua <https://github.com/BaguaSys/bagua>`_ library, with advanced
distributed training algorithms and system optimizations.
This strategy requires the `bagua` package to be installed. See
`installation guide <https://tutorials.baguasys.com/installation>`_ for more information.
The :class:`BaguaStrategy` is only supported on GPU and on Linux systems.
Arguments:
algorithm: Distributed algorithm used to do the actual communication and update. Built-in algorithms
include "gradient_allreduce", "bytegrad", "decentralized", "low_precision_decentralized", "qadam" and
"async".
flatten: Whether to flatten the Bagua communication buckets. The flatten operation will reset data
pointer of bucket tensors so that they can use faster code paths.
bagua_kwargs: Additional keyword arguments that will be passed to initialize the Bagua algorithm. More
details on keyword arguments accepted for each algorithm can be found in the
`documentation <https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/index.html>`_.
"""
        if not _BAGUA_AVAILABLE:
            raise MisconfigurationException(
                "To use the `BaguaStrategy`, you must have `Bagua` installed. Use `pip install bagua` to install it."
            )

        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

        self._bagua_algorithm = algorithm
        self._bagua_flatten = flatten
        self._bagua_kwargs = bagua_kwargs

    @property
    def lightning_module(self) -> "pl.LightningModule":
        model = self._model
        if isinstance(model, BaguaDistributedDataParallel):
            model = model.module
        return unwrap_lightning_module(model)  # type: ignore[arg-type]

    def setup_distributed(self) -> None:
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        self._init_bagua_distributed()

    def _init_bagua_distributed(self) -> None:
        self._set_node_environment_variables()
        log.info(
            "Initializing Bagua Distributed: "
            f"GLOBAL_RANK: {self.global_rank}, "
            f"MEMBER: {self.global_rank + 1}/{self.world_size}"
        )

        # the device must be set before the Bagua distributed environment is initialized
        # Note: setup_environment calls super().setup_distributed after calling init_distributed()
        torch.cuda.set_device(self.local_rank)

        if not is_initialized():
            bagua.init_process_group()

    def _set_node_environment_variables(self) -> None:
        """Set the environment variables as required by the :func:`bagua.init_process_group` call.

        This enables the use of other cluster environments which don't set these exact variables, e.g., Bagua can be
        launched with ``torch.distributed.run``.
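
        For example, a hypothetical launch command for a single node with two GPUs (the script name
        ``train.py`` is a placeholder)::

            python -m torch.distributed.run --nproc_per_node=2 train.py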
"""
os.environ["MASTER_ADDR"] = self.cluster_environment.main_address # type: ignore[union-attr]
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) # type: ignore[union-attr]
os.environ["RANK"] = str(self.global_rank)
os.environ["NODE_RANK"] = str(self.node_rank)
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["LOCAL_RANK"] = str(self.local_rank)

    def _check_qadam_optimizer(self) -> None:
        has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])

        if not has_qadam_optimizer or len(self.optimizers) > 1 or len(self.lr_scheduler_configs) > 1:
            raise MisconfigurationException("Bagua QAdam can only accept one QAdamOptimizer and one LR Scheduler.")

        self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]

    def configure_ddp(self) -> None:
        model = LightningBaguaModule(self.model)  # type: ignore[arg-type]
        self._model = self._setup_model(model)

        # start the background communication for async algorithm
        assert self.lightning_module.trainer is not None
        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
            self.model.bagua_algorithm.resume(self.model)  # type: ignore

    def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
        """Wraps the model into a Bagua distributed module."""
        if self._bagua_algorithm == "qadam":
            self._check_qadam_optimizer()

        algorithm = Algorithm.init(self._bagua_algorithm, **self._bagua_kwargs)
        return BaguaDistributedDataParallel(
            module=model,
            optimizers=self.optimizers,
            algorithm=algorithm,
            gradient_as_bucket_view=self._bagua_flatten,
        )

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__name__}",
        )

    def teardown(self) -> None:
        # abort the background communication for async algorithm
        assert self.lightning_module.trainer is not None
        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
            self.model.bagua_algorithm.abort(self.model)  # type: ignore

        if isinstance(self.model, BaguaDistributedDataParallel):
            self.model = self.lightning_module

        if self.root_device.type == "cuda":
            # GPU teardown
            log.detail(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()

    def barrier(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        if is_initialized():
            barrier()

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        return broadcast_object(obj, src)

    def reduce(
        self, tensor: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: The tensor to sync and reduce.
            group: The process group to gather results from. Defaults to all processes (world).
            reduce_op: The reduction operation.
                Can also be a string 'sum' or ReduceOp.

        Return:
            The reduced value. If the input was not a tensor, it is returned unchanged.
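
        Example (an illustrative sketch, not taken from this file; ``strategy`` stands for the active
        :class:`BaguaStrategy` instance and ``loss`` for a tensor present on every process)::

            mean_loss = strategy.reduce(loss, reduce_op="mean")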
"""
        if not isinstance(tensor, torch.Tensor):
            return tensor
        if group is not None:
            raise ValueError("`Bagua` does not support allreduce using a subcommunicator at this time. Unset `group`.")

        if reduce_op is None:
            op = BaguaReduceOp.AVG
        else:
            op = _bagua_reduce_ops.get(reduce_op, None)
            if op is None:
                raise ValueError(f"Unrecognized `reduce_op` for `BaguaStrategy`: {reduce_op}")

        allreduce_inplace(tensor, op=op)
        return tensor