Add Bagua Strategy #11146

Merged
merged 89 commits into master from wangraying:bagua-plugin on Feb 4, 2022
Changes from 84 commits
Commits
89 commits
b231a81
init commit
wangraying Dec 14, 2021
0db1584
.
wangraying Dec 15, 2021
33646db
support qadam and remove bagua algorithm enum
wangraying Dec 16, 2021
433e2ed
add bagua communication api
wangraying Dec 17, 2021
b3b23ff
allow multiple wraps
wangraying Dec 17, 2021
709d311
remove
wangraying Dec 17, 2021
a9472bd
.
wangraying Dec 17, 2021
d351050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2021
39ddbb9
update reduce
wangraying Dec 17, 2021
174a8c5
Merge branch 'bagua-plugin' of https://github.com/wangraying/pytorch-…
wangraying Dec 17, 2021
cac63db
minor fix
wangraying Dec 18, 2021
cd75ee9
Merge branch 'master' into bagua-plugin
awaelchli Dec 23, 2021
f390d3a
update bagua to strategy api
awaelchli Dec 23, 2021
ff9ed13
merge pre_dispatch and setup()
awaelchli Dec 23, 2021
110b6e8
move start_training to setup()
awaelchli Dec 23, 2021
71eff2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2021
d2c1bf2
fix undefined import
awaelchli Dec 23, 2021
d5b03d8
update for bagua strategy
wangraying Dec 27, 2021
9f3df78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2021
ad8aac5
update doc
wangraying Dec 27, 2021
385020c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2021
2d6b8c9
ci
wangraying Dec 30, 2021
10f2fff
.
wangraying Dec 30, 2021
32f2e12
Apply suggestions from code review
wangraying Jan 4, 2022
0e2e839
Merge branch 'bagua-plugin' of https://github.com/wangraying/pytorch-…
wangraying Jan 4, 2022
466fa52
add tests
wangraying Jan 4, 2022
3b48243
Merge branch 'bagua-plugin' of https://github.com/wangraying/pytorch-…
wangraying Jan 4, 2022
66e9a2c
update ci
wangraying Jan 4, 2022
6ec6990
Merge branch 'master' into bagua-plugin
wangraying Jan 4, 2022
c2d3b40
update to master
wangraying Jan 4, 2022
7319a5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2022
de3cffd
fix mypy
wangraying Jan 4, 2022
65097de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2022
69ccbd8
ci
wangraying Jan 4, 2022
90f4ff4
Merge branch 'bagua-plugin' of https://github.com/wangraying/pytorch-…
wangraying Jan 4, 2022
d9c7e9c
update test
wangraying Jan 4, 2022
9432714
add assertion to test
wangraying Jan 5, 2022
aa622cf
Merge branch 'master' into bagua-plugin
wangraying Jan 5, 2022
b0568f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2022
d7966a7
small update
wangraying Jan 6, 2022
fc321ee
small update
wangraying Jan 6, 2022
42d5baa
Merge branch 'master' into bagua-plugin
wangraying Jan 17, 2022
53f3a28
fix for ci failure
wangraying Jan 17, 2022
19732f1
for ci
wangraying Jan 17, 2022
0d92130
fmt
wangraying Jan 17, 2022
94786b8
add doc
wangraying Jan 18, 2022
a4e2332
update
wangraying Jan 18, 2022
dbe9268
.
wangraying Jan 18, 2022
5cf1bce
.
wangraying Jan 18, 2022
76a1350
refine doc
wangraying Jan 18, 2022
c498608
update doc
wangraying Jan 18, 2022
9e4abe0
Apply suggestions from code review
wangraying Jan 18, 2022
54a30b2
Update docs/source/accelerators/gpu.rst
wangraying Jan 18, 2022
90b9405
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2022
38e665d
update doc
wangraying Jan 18, 2022
b0a9f51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2022
3815fc7
add test
wangraying Jan 18, 2022
202c3bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2022
b8aad03
ci
wangraying Jan 18, 2022
0ee1802
Merge branch 'bagua-plugin' of https://github.com/wangraying/pytorch-…
wangraying Jan 18, 2022
fc3f2c8
Update docs/source/accelerators/gpu.rst
wangraying Jan 18, 2022
75b548f
update doc
wangraying Jan 18, 2022
062ffc1
add brief intro
wangraying Jan 18, 2022
3bae9ed
fix test
wangraying Jan 18, 2022
c251e74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2022
fc6d415
typo
wangraying Jan 18, 2022
7d74281
Apply suggestions from code review
wangraying Jan 19, 2022
0f0f1bc
apply suggestions
wangraying Jan 19, 2022
78c4f6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2022
e13e794
ci
wangraying Jan 19, 2022
53a0ad0
Apply suggestions from code review
wangraying Jan 19, 2022
54e052e
Apply suggestions from code review
wangraying Jan 20, 2022
997d22d
update doc and code
wangraying Jan 20, 2022
b2df6b1
Merge branch 'master' into bagua-plugin
wangraying Jan 21, 2022
8dcb9d0
fix for refactor
wangraying Jan 21, 2022
de55851
move property to top
awaelchli Jan 24, 2022
c840157
fix typo
awaelchli Jan 24, 2022
a1660ab
update
wangraying Jan 25, 2022
2812635
Update pytorch_lightning/strategies/bagua.py
wangraying Jan 25, 2022
8ccd249
Merge remote-tracking branch 'origin/master' into bagua-plugin
wangraying Jan 25, 2022
f378851
Merge branch 'master' into bagua-plugin
carmocca Feb 3, 2022
a05b745
minor updates
awaelchli Feb 4, 2022
dce2d8a
Merge branch 'master' into bagua-plugin
awaelchli Feb 4, 2022
347dd9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2022
0a1b278
fix
awaelchli Feb 4, 2022
716adf6
Merge branch 'bagua-plugin' of github.com:wangraying/pytorch-lightnin…
awaelchli Feb 4, 2022
0881ffb
minor changes
wangraying Feb 4, 2022
29aab45
Update RunIf
carmocca Feb 4, 2022
8d19cbe
Remove unnecessary function
carmocca Feb 4, 2022
1 change: 1 addition & 0 deletions .azure-pipelines/gpu-tests.yml
@@ -52,6 +52,7 @@ jobs:
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
pip install fairscale==0.4.0
pip install deepspeed==0.5.7
pip install bagua-cuda102==0.9.0
pip install . --requirement requirements/devel.txt
pip list
displayName: 'Install dependencies'
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when using `DistributedSampler` during validation/testing ([#11479](https://github.com/PyTorchLightning/pytorch-lightning/pull/11479))


- Added support for `Bagua` training strategy ([#11146](https://github.com/PyTorchLightning/pytorch-lightning/pull/11146))


### Changed

- Set the `prog_bar` flag to False in `LightningModule.log_grad_norm` ([#11472](https://github.com/PyTorchLightning/pytorch-lightning/pull/11472))
122 changes: 122 additions & 0 deletions docs/source/accelerators/gpu.rst
@@ -282,6 +282,7 @@ Lightning allows multiple ways of training
- DistributedDataParallel (``strategy='ddp_spawn'``) (multiple-gpus across many machines (spawn based)).
- DistributedDataParallel 2 (``strategy='ddp2'``) (DP in a machine, DDP across machines).
- Horovod (``strategy='horovod'``) (multi-machine, multi-gpu, configured at runtime)
- Bagua (``strategy='bagua'``) (multiple-gpus across many machines with advanced training algorithms)
- TPUs (``tpu_cores=8|x``) (tpu or TPU pod)

.. note::
@@ -489,6 +490,127 @@ number of worker processes:
See the official `Horovod documentation <https://horovod.readthedocs.io/en/stable>`_ for details
on installation and performance tuning.


Bagua
^^^^^
`Bagua <https://github.com/BaguaSys/bagua>`_ is a deep learning training acceleration framework which supports
multiple advanced distributed training algorithms including:

- `Gradient AllReduce <https://tutorials.baguasys.com/algorithms/gradient-allreduce>`_ for centralized synchronous communication, where gradients are averaged among all workers.
- `Decentralized SGD <https://tutorials.baguasys.com/algorithms/decentralized>`_ for decentralized synchronous communication, where each worker exchanges data with one or a few specific workers.
- `ByteGrad <https://tutorials.baguasys.com/algorithms/bytegrad>`_ and `QAdam <https://tutorials.baguasys.com/algorithms/q-adam>`_ for low precision communication, where data is compressed into low precision before communication.
- `Asynchronous Model Average <https://tutorials.baguasys.com/algorithms/async-model-average>`_ for asynchronous communication, where workers are not required to be synchronized in the same iteration in a lock-step style.

By default, Bagua uses the *Gradient AllReduce* algorithm, which is also the algorithm implemented in Distributed Data Parallel and Horovod,
but Bagua can usually achieve higher training throughput thanks to its backend written in Rust.

.. code-block:: python

# train on 4 GPUs (using Bagua mode)
trainer = Trainer(strategy="bagua", accelerator="gpu", devices=4)


By specifying the ``algorithm`` argument of ``BaguaStrategy``, you can select the more advanced training algorithms offered by Bagua:


.. code-block:: python

# train on 4 GPUs, using Bagua Gradient AllReduce algorithm
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="gradient_allreduce"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Bagua ByteGrad algorithm
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="bytegrad"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Bagua Decentralized SGD
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="decentralized"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Bagua Low Precision Decentralized SGD
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="low_precision_decentralized"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Asynchronous Model Average algorithm, with a synchronization interval of 100ms
trainer = Trainer(
strategy=BaguaStrategy(algorithm="async", sync_interval_ms=100),
accelerator="gpu",
devices=4,
)

To use *QAdam*, we need to initialize
`QAdamOptimizer <https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/q_adam/index.html#bagua.torch_api.algorithms.q_adam.QAdamOptimizer>`_ first:

.. code-block:: python

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import BaguaStrategy
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer


class MyModel(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # initialize QAdam Optimizer
        return QAdamOptimizer(self.parameters(), lr=0.05, warmup_steps=100)


model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=BaguaStrategy(algorithm="qadam"),
)
trainer.fit(model)

Bagua relies on its own `launcher <https://tutorials.baguasys.com/getting-started/#launch-job>`_ to schedule jobs.
The examples below use ``bagua.distributed.launch``, which follows the ``torch.distributed.launch`` API:

.. code-block:: bash

# start training with 8 GPUs on a single node
python -m bagua.distributed.launch --nproc_per_node=8 train.py

# Run on node1 to start training on two nodes (node1 and node2), 8 GPUs per node
python -m bagua.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=hostname1 --master_port=port1 train.py

# Run on node2 to start training on two nodes (node1 and node2), 8 GPUs per node
python -m bagua.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=hostname1 --master_port=port1 train.py


If an SSH service with passwordless login is available on each node, you can launch the distributed job from a
single node with ``baguarun``, which has a syntax similar to ``mpirun``. When starting the job, ``baguarun``
automatically spawns new processes on each training node listed in the ``--host_list`` option, where each node
is described as an IP address followed by an SSH port.

.. code-block:: bash

# Run on node1 (or node2) to start training on two nodes (node1 and node2), 8 GPUs per node
baguarun --host_list hostname1:ssh_port1,hostname2:ssh_port2 --nproc_per_node=8 --master_port=port1 train.py


.. note:: You can also start training in the same way as with Distributed Data Parallel. However, system optimizations like
`Bagua-Net <https://tutorials.baguasys.com/more-optimizations/bagua-net>`_ and
`Performance autotuning <https://tutorials.baguasys.com/performance-autotuning/>`_ can only be enabled through the Bagua
launcher. Notably, with ``Bagua-Net``, even plain Distributed Data Parallel can achieve
better performance without modifying the training script.
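
For illustration, a hedged sketch of enabling ``Bagua-Net`` through the Bagua launcher. The ``--enable_bagua_net`` flag name is an assumption taken from the Bagua-Net tutorial, so verify it against your installed Bagua version:

.. code-block:: bash

    # Sketch only: the flag name below is assumed from the Bagua-Net tutorial;
    # confirm the exact option with `python -m bagua.distributed.launch --help`.
    python -m bagua.distributed.launch --enable_bagua_net --nproc_per_node=8 train.py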


See `Bagua Tutorials <https://tutorials.baguasys.com/>`_ for more details on installation and advanced features.


DP/DDP2 caveats
^^^^^^^^^^^^^^^
In DP and DDP2 each GPU within a machine sees a portion of a batch.
1 change: 1 addition & 0 deletions docs/source/api_references.rst
@@ -43,6 +43,7 @@ Strategy API
:nosignatures:
:template: classtemplate.rst

BaguaStrategy
DDP2Strategy
DDPFullyShardedStrategy
DDPShardedStrategy
1 change: 1 addition & 0 deletions docs/source/extensions/plugins.rst
@@ -107,6 +107,7 @@ Training Strategies
DDPShardedStrategy
DDPSpawnShardedStrategy
DDPSpawnStrategy
BaguaStrategy
DeepSpeedStrategy
HorovodStrategy
SingleTPUStrategy
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.environments.bagua_environment import BaguaEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
62 changes: 62 additions & 0 deletions pytorch_lightning/plugins/environments/bagua_environment.py
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)


class BaguaEnvironment(ClusterEnvironment):
    """Environment for distributed training with `Bagua <https://tutorials.baguasys.com/>`_"""

    @property
    def creates_processes_externally(self) -> bool:
        return True

    @property
    def main_address(self) -> str:
        return os.environ.get("MASTER_ADDR", "127.0.0.1")

    @property
    def main_port(self) -> int:
        return int(os.environ.get("MASTER_PORT", -1))

    @property
    def service_port(self) -> int:
        return int(os.environ.get("BAGUA_SERVICE_PORT", -1))

    @staticmethod
    def detect() -> bool:
        return "BAGUA_SERVICE_PORT" in os.environ

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        log.debug("`BaguaEnvironment.set_world_size` was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def set_global_rank(self, rank: int) -> None:
        log.debug("`BaguaEnvironment.set_global_rank` was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        return int(os.environ.get("LOCAL_RANK", 0))

    def node_rank(self) -> int:
        return int(os.environ.get("NODE_RANK", 0))
1 change: 1 addition & 0 deletions pytorch_lightning/strategies/__init__.py
@@ -1,5 +1,6 @@
from pathlib import Path

from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401