add deepspeed grad ckpt #30233

Open · wants to merge 1 commit into main
Conversation

@SeunghyunSEO SeunghyunSEO commented Apr 13, 2024

hi @younesbelkada, sorry for the late PR.
it's been a busy few days ;-)

i wrote a draft for deepspeed gradient checkpointing.
i tried to change the original code as little as i could, but two things were inevitable:

    1. `gradient_checkpointing_kwargs` is added as a member variable so the `num_checkpoints` and `use_deepspeed_grad_ckpt` keys can be parsed to decide whether to use the advanced grad ckpt path or not (see the sketch after this list)
    2. the second thing is not about code: deepspeed's checkpoint function forces me to call `deepspeed.init_distributed()` because of this line.
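
roughly, the dispatch i have in mind looks like this (a simplified sketch only; `_select_gradient_checkpointing_func` is a made-up name for illustration, the actual change is in the diff):

from functools import partial

import torch
import deepspeed

def _select_gradient_checkpointing_func(gradient_checkpointing_kwargs):
    kwargs = dict(gradient_checkpointing_kwargs or {})
    if kwargs.pop("use_deepspeed_grad_ckpt", False):
        # deepspeed's activation checkpointing is configured once up front and
        # needs the distributed backend alive (hence deepspeed.init_distributed())
        deepspeed.checkpointing.configure(
            None,  # mpu_: no model parallelism in this sketch
            num_checkpoints=kwargs.pop("num_checkpoints", None),
            checkpoint_in_cpu=kwargs.pop("checkpoint_in_cpu", False),
        )
        return deepspeed.checkpointing.checkpoint
    # default path: plain torch activation checkpointing, as on main
    return partial(torch.utils.checkpoint.checkpoint, **kwargs)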

here is my code for sanity-checking the advanced activation checkpointing:

import os
import copy
import random

import torch
import deepspeed
from deepspeed import get_accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments


def _reset_seeds(seed=1234):
    torch.manual_seed(seed)
    random.seed(seed)

def _get_dummy_inputs(B=2, T=10, seed=1234):
    _reset_seeds(seed)
    # note: rand().long() floors every entry to 0, so input_ids are all token 0;
    # that is enough for comparing gradients between the two checkpointing paths
    return {
        'input_ids': torch.rand(B, T).long().cuda(),
        'attention_mask' : torch.ones(B, T).cuda(),
    }

def _get_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

if __name__ == "__main__":

    ## set distributed state (deepspeed checkpointing checks the rank for printing...)
    local_rank = int(os.environ["LOCAL_RANK"])
    get_accelerator().set_device(local_rank)
    device = torch.device(get_accelerator().device_name(), local_rank)
    deepspeed.init_distributed()

    ## get model, tokenizer and dummy optimizer
    model_path = "lmsys/vicuna-7b-v1.5"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path).cuda().train() # fp16 can cause nan 
    optimizer = _get_optimizer(model)

    ## vanilla gradient checkpointing
    model.gradient_checkpointing_enable()
    inputs = _get_dummy_inputs()
    optimizer.zero_grad()
    output1 = model(**inputs)
    output1.logits.sum().backward()
    tmp_grad1 = copy.deepcopy(model.model.layers[0].self_attn.q_proj.weight.grad.cpu())

    ## deepspeed CPU offloading, selective gradient checkpointing
    gradient_checkpointing_kwargs = {
        "use_deepspeed_grad_ckpt" : True, # route to deepspeed.checkpointing instead of torch.utils.checkpoint
        "num_checkpoints" : 4,            # number of activation checkpoints stored during the forward
        "checkpoint_in_cpu" : True,       # offload the checkpointed activations to CPU memory
    }
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
    inputs = _get_dummy_inputs()
    optimizer.zero_grad()
    output2 = model(**inputs)
    output2.logits.sum().backward()
    tmp_grad2 = copy.deepcopy(model.model.layers[0].self_attn.q_proj.weight.grad.cpu())

    grad_allclose = torch.allclose(tmp_grad1, tmp_grad2, rtol=1e-05, atol=1e-08)
    assert grad_allclose
    print(f'''
    grad_allclose : {grad_allclose}
    tmp_grad1     : {tmp_grad1}
    tmp_grad2     : {tmp_grad2}
    ''')

and the result looked like this:

/path/to/dir/transformers$ python -m torch.distributed.launch test_act_ckpt.py
/path/to/dir/venv/transformers_pr/lib/python3.10/site-packages/torch/distributed/launch.py:183: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects `--local-rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

  warnings.warn(
[2024-04-13 08:46:25,432] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-13 08:46:26,405] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-04-13 08:46:26,405] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.50s/it]
/path/to/dir/transformers/src/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
  warnings.warn(
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
[2024-04-13 08:46:38,841] [INFO] [checkpointing.py:539:forward] Activation Checkpointing Information
[2024-04-13 08:46:38,841] [INFO] [checkpointing.py:540:forward] ----Partition Activations False, CPU CHECKPOINTING True
[2024-04-13 08:46:38,841] [INFO] [checkpointing.py:541:forward] ----contiguous Memory Checkpointing False with 4 total layers
[2024-04-13 08:46:38,841] [INFO] [checkpointing.py:543:forward] ----Synchronization False
[2024-04-13 08:46:38,841] [INFO] [checkpointing.py:544:forward] ----Profiling time in checkpointing False

    grad_allclose : True
    tmp_grad1     : tensor([[-2.6913e-12,  1.5395e-12,  8.9857e-13,  ..., -8.0048e-13,
          5.5549e-12, -4.5321e-13],
        [ 4.7248e-12, -2.7027e-12, -1.5775e-12,  ...,  1.4053e-12,
         -9.7519e-12,  7.9564e-13],
        [-4.4621e-12,  2.5525e-12,  1.4898e-12,  ..., -1.3271e-12,
          9.2097e-12, -7.5140e-13],
        ...,
        [-1.0189e-12,  5.8284e-13,  3.4018e-13,  ..., -3.0304e-13,
          2.1030e-12, -1.7158e-13],
        [-2.7976e-12,  1.6004e-12,  9.3407e-13,  ..., -8.3210e-13,
          5.7743e-12, -4.7112e-13],
        [ 7.3264e-12, -4.1910e-12, -2.4461e-12,  ...,  2.1791e-12,
         -1.5122e-11,  1.2337e-12]])
    tmp_grad2     : tensor([[-2.6913e-12,  1.5395e-12,  8.9857e-13,  ..., -8.0048e-13,
          5.5549e-12, -4.5321e-13],
        [ 4.7248e-12, -2.7027e-12, -1.5775e-12,  ...,  1.4053e-12,
         -9.7519e-12,  7.9564e-13],
        [-4.4621e-12,  2.5525e-12,  1.4898e-12,  ..., -1.3271e-12,
          9.2097e-12, -7.5140e-13],
        ...,
        [-1.0189e-12,  5.8284e-13,  3.4018e-13,  ..., -3.0304e-13,
          2.1030e-12, -1.7158e-13],
        [-2.7976e-12,  1.6004e-12,  9.3407e-13,  ..., -8.3210e-13,
          5.7743e-12, -4.7112e-13],
        [ 7.3264e-12, -4.1910e-12, -2.4461e-12,  ...,  2.1791e-12,
         -1.5122e-11,  1.2337e-12]])
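
as an aside, the deprecation warning above is harmless; the script should run the same with torchrun, which sets LOCAL_RANK in the environment by default:

torchrun test_act_ckpt.py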

my environment was:

python -c "import torch; print(torch.__version__); \
import transformers; print(transformers.__version__); \
import deepspeed; print(deepspeed.__version__)"
2.2.2+cu118
4.40.0.dev0
[2024-04-13 08:44:58,468] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
0.14.0

@SeunghyunSEO SeunghyunSEO changed the title draft for deepspeed grad ckpt add deepspeed grad ckpt Apr 13, 2024
@github-actions github-actions bot closed this May 22, 2024
@huggingface huggingface deleted a comment from github-actions bot May 22, 2024
@amyeroberts amyeroberts reopened this May 22, 2024
@amyeroberts (Collaborator)

cc @younesbelkada

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
