Merge branch 'master' into feature/mark-forward-method
Showing 24 changed files with 1,902 additions and 105 deletions.
@@ -0,0 +1,45 @@
## Tensor Parallel and 2D Parallel

This example shows how to apply tensor parallelism to a model (here Llama 2 7B) with the `ModelParallelStrategy`, and how to combine it with FSDP (2D parallelism).
Running this example requires PyTorch 2.3+ and a machine with at least 4 GPUs, each with 24 GB of memory.

```bash
pip install 'torch>=2.3'
```

Navigate to this example folder and run the training script:

```bash
cd examples/fabric/tensor_parallel
python train.py
```
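
Under the hood, `train.py` sets up Fabric with the `ModelParallelStrategy`. The sketch below is an assumption-labeled outline rather than the example's exact code: the submodule names in the parallelization plan (`feed_forward.w1`, `feed_forward.w2`) are hypothetical, and it assumes Fabric names the tensor-parallel mesh dimension `"tensor_parallel"`:

```python
import lightning as L
import torch.nn as nn
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    # Shard selected submodules over the tensor-parallel axis of the mesh.
    # The names below are hypothetical; a real plan targets the model's
    # actual projection layers (e.g., attention and MLP linears in Llama).
    tp_mesh = device_mesh["tensor_parallel"]
    plan = {
        "feed_forward.w1": ColwiseParallel(),  # split weight columns across GPUs
        "feed_forward.w2": RowwiseParallel(),  # split rows, reduce the output
    }
    return parallelize_module(model, tp_mesh, plan)


strategy = ModelParallelStrategy(
    parallelize_fn=parallelize,
    data_parallel_size=2,    # FSDP axis of the 2D mesh
    tensor_parallel_size=2,  # tensor-parallel axis
)
fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()
```

With `data_parallel_size=2` and `tensor_parallel_size=2`, the 4 GPUs form a 2x2 device mesh: FSDP shards along one axis and tensor parallelism along the other, which is the 2D layout this example demonstrates.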

You should see an output like this:

```
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------
Number of model parameters: 6.7 B
Starting training ...
Iteration 0 complete
Iteration 1 complete
Iteration 2 complete
Iteration 3 complete
Iteration 4 complete
Iteration 5 complete
Iteration 6 complete
Iteration 7 complete
Saving a (distributed) checkpoint ...
Training successfully completed!
Peak memory usage: 17.95 GB
```

> [!NOTE]
> The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues).
@@ -0,0 +1,21 @@
```python
import torch
from torch.utils.data import Dataset


class RandomTokenDataset(Dataset):
    def __init__(self, vocab_size: int, seq_length: int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.tokens = torch.randint(
            self.vocab_size,
            size=(len(self), self.seq_length + 1),
            # Set a seed to make this toy dataset the same on each rank
            # Fabric will add a `DistributedSampler` to shard the data correctly
            generator=torch.Generator().manual_seed(42),
        )

    def __len__(self) -> int:
        return 128

    def __getitem__(self, item: int) -> torch.Tensor:
        return self.tokens[item]
```
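
As a usage illustration (not part of the example's files), the dataset plugs into a standard `DataLoader`. The values here are assumptions: `vocab_size=32000` is chosen because it is Llama 2's vocabulary size, and the one-token shift shows why each row stores `seq_length + 1` tokens:

```python
from torch.utils.data import DataLoader

# Illustrative values; 32000 matches the Llama 2 vocabulary size
dataset = RandomTokenDataset(vocab_size=32000, seq_length=128)
dataloader = DataLoader(dataset, batch_size=8)

for batch in dataloader:
    # The extra token per row allows a next-token shift:
    # inputs are tokens[:-1], labels are tokens[1:]
    inputs, labels = batch[:, :-1], batch[:, 1:]
```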