[Elastic] Graceful node shutdown results in repeated training samples + proposed solution #3399

ASDen commented Feb 3, 2022

Environment:

  1. Framework: PyTorch
  2. Framework version: 1.10
  3. Horovod version: latest master
  4. Python version: 3.7
  5. OS and version: Ubuntu 20.04

Checklist:

  1. Did you search issues to find if somebody asked this question before? yes
  2. If your question is about hang, did you read this doc? yes
  3. If your question is about docker, did you read this doc? yes
  4. Did you check if your question is answered in the troubleshooting guide? yes

Bug report:

Reproducing the problem

The reproduction code is based on this previous issue, with a small modification: for simplicity, the training loop prints each item itself (a number representing its index).

#! /usr/bin/env python3
# -*- coding: utf-8 -*-


import time
import torch
import horovod.torch as hvd


BATCH_SIZE_PER_GPU = 2


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, n):
        self.n = n
    
    def __getitem__(self, index):
        index = index % self.n
        return index

    def __len__(self):
        return self.n


@hvd.elastic.run
def train(state, data_loader, a):
    rank = hvd.rank()
    print(f"train rank={rank}")
    total_epoch = 100
    for epoch in range(state.epoch, total_epoch):
        print(f"epoch={epoch}")

        print("Epoch {} / {}, Start training".format(epoch, total_epoch))
        print(f"train... rank={rank}")
        print(f"start enumerate train_loader... rank={rank}")
        batch_offset = state.batch
        for i, d in enumerate(data_loader):
            state.batch = batch_idx = batch_offset + i
            if state.batch % 50 == 0:
                t1 = time.time()
                state.commit()
                print(f"time: {time.time() - t1}")
            state.check_host_updates()
            b = hvd.allreduce(a)
            # print(f"b: {b}")
            state.train_sampler.record_batch(i, BATCH_SIZE_PER_GPU)
    
            msg = 'Epoch: [{0}][{1}/{2}]\t Consumed {3}'.format(
                state.epoch, state.batch, len(data_loader), d)

            print(msg)
            time.sleep(1)
        state.epoch += 1
        state.batch = 0
        data_loader.sampler.set_epoch(epoch)
        state.commit()


def main():
    hvd.init()
    torch.manual_seed(219)
    torch.cuda.set_device(hvd.local_rank())

    dataset = MyDataset(2000)
    sampler = hvd.elastic.ElasticSampler(dataset, shuffle=False)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=2,
        sampler=sampler,
        worker_init_fn=None,
        drop_last=True,
    )
    a = torch.Tensor([1,2,3,4])
    state = hvd.elastic.TorchState(epoch=0,
                                   train_sampler=sampler,
                                   batch=0)

    train(state, data_loader, a)


if __name__ == "__main__":
    main()

To simulate a graceful node shutdown followed by resumption every few seconds, I use the following host-discovery script (it assumes a Ray cluster is present, but it is easy to adapt to any other setup):

#!/usr/bin/env python3

import ray
from datetime import datetime

ray.init('auto', logging_level='critical')

# Collect the IPs of all alive worker nodes (the head node is excluded).
ips = [k['NodeManagerAddress'] for k in ray.nodes() if k['Alive'] and 'head' not in k['NodeManagerHostname']]

# Drop the first worker every other 10-second window, simulating a node
# being gracefully removed and then added back.
ips = ips[datetime.now().second // 10 % 2:]

print('\n'.join(ips))
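For reference, the job is launched with horovodrun in elastic mode, pointing it at the discovery script above (which must be executable). The file names and process counts below are placeholders, not the exact values from my setup:

horovodrun -np 2 --min-np 1 --max-np 2 --host-discovery-script ./discover_hosts.py python train_elastic.py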

Running the job produces output like the following:

[0]<stdout>:train rank=0
[0]<stdout>:epoch=0
[0]<stdout>:Epoch 0 / 100, Start training
[0]<stdout>:train... rank=0
[0]<stdout>:start enumerate train_loader... rank=0
[1]<stdout>:train rank=1
[1]<stdout>:epoch=0
[1]<stdout>:Epoch 0 / 100, Start training
[1]<stdout>:train... rank=1
[1]<stdout>:start enumerate train_loader... rank=1
[0]<stdout>:time: 0.019937753677368164
[1]<stdout>:time: 0.007147789001464844
[1]<stdout>:Epoch: [0][0/500]	 Consumed tensor([1, 3])
[0]<stdout>:Epoch: [0][0/500]	 Consumed tensor([0, 2])
[0]<stdout>:Epoch: [0][1/500]	 Consumed tensor([4, 6])
[1]<stdout>:Epoch: [0][1/500]	 Consumed tensor([5, 7])
[0]<stdout>:Epoch: [0][2/500]	 Consumed tensor([ 8, 10])
[1]<stdout>:Epoch: [0][2/500]	 Consumed tensor([ 9, 11])
[0]<stdout>:Epoch: [0][3/500]	 Consumed tensor([12, 14])
[1]<stdout>:Epoch: [0][3/500]	 Consumed tensor([13, 15])
[0]<stdout>:train rank=0
[0]<stdout>:epoch=0
[0]<stdout>:Epoch 0 / 100, Start training
[0]<stdout>:train... rank=0
[0]<stdout>:start enumerate train_loader... rank=0
[1]<stdout>:train rank=1
[1]<stdout>:epoch=0
[1]<stdout>:Epoch 0 / 100, Start training
[1]<stdout>:train... rank=1
[1]<stdout>:start enumerate train_loader... rank=1
[0]<stdout>:Epoch: [0][4/500]	 Consumed tensor([0, 2])
[1]<stdout>:Epoch: [0][4/500]	 Consumed tensor([1, 3])
[1]<stdout>:Epoch: [0][5/500]	 Consumed tensor([5, 7])
[0]<stdout>:Epoch: [0][5/500]	 Consumed tensor([4, 6])
[1]<stdout>:Epoch: [0][6/500]	 Consumed tensor([ 9, 11])
[0]<stdout>:Epoch: [0][6/500]	 Consumed tensor([ 8, 10])
[1]<stdout>:Epoch: [0][7/500]	 Consumed tensor([13, 15])
[0]<stdout>:Epoch: [0][7/500]	 Consumed tensor([12, 14])
[1]<stdout>:Epoch: [0][8/500]	 Consumed tensor([17, 19])
[0]<stdout>:Epoch: [0][8/500]	 Consumed tensor([16, 18])
[0]<stdout>:Epoch: [0][9/500]	 Consumed tensor([20, 22])
[1]<stdout>:Epoch: [0][9/500]	 Consumed tensor([21, 23])
[1]<stdout>:Epoch: [0][10/500]	 Consumed tensor([25, 27])
[0]<stdout>:Epoch: [0][10/500]	 Consumed tensor([24, 26])
[0]<stdout>:Epoch: [0][11/500]	 Consumed tensor([28, 30])
[1]<stdout>:Epoch: [0][11/500]	 Consumed tensor([29, 31])
[1]<stdout>:Epoch: [0][12/500]	 Consumed tensor([33, 35])
[0]<stdout>:Epoch: [0][12/500]	 Consumed tensor([32, 34])
[0]<stdout>:Epoch: [0][13/500]	 Consumed tensor([36, 38])
[1]<stdout>:Epoch: [0][13/500]	 Consumed tensor([37, 39])
[0]<stdout>:train rank=0
[0]<stdout>:epoch=0
[0]<stdout>:Epoch 0 / 100, Start training
[0]<stdout>:train... rank=0
[0]<stdout>:start enumerate train_loader... rank=0
[1]<stdout>:train rank=1
[1]<stdout>:epoch=0
[1]<stdout>:Epoch 0 / 100, Start training
[1]<stdout>:train... rank=1
[1]<stdout>:start enumerate train_loader... rank=1
[1]<stdout>:Epoch: [0][14/500]	 Consumed tensor([1, 3])
[0]<stdout>:Epoch: [0][14/500]	 Consumed tensor([0, 2])
[1]<stdout>:Epoch: [0][15/500]	 Consumed tensor([5, 7])
[0]<stdout>:Epoch: [0][15/500]	 Consumed tensor([4, 6])
[0]<stdout>:Epoch: [0][16/500]	 Consumed tensor([ 8, 10])
[1]<stdout>:Epoch: [0][16/500]	 Consumed tensor([ 9, 11])
[1]<stdout>:Epoch: [0][17/500]	 Consumed tensor([13, 15])
[0]<stdout>:Epoch: [0][17/500]	 Consumed tensor([12, 14])
[0]<stdout>:Epoch: [0][18/500]	 Consumed tensor([16, 18])
[1]<stdout>:Epoch: [0][18/500]	 Consumed tensor([17, 19])
[0]<stdout>:Epoch: [0][19/500]	 Consumed tensor([20, 22])
[1]<stdout>:Epoch: [0][19/500]	 Consumed tensor([21, 23])
[0]<stdout>:Epoch: [0][20/500]	 Consumed tensor([24, 26])
[1]<stdout>:Epoch: [0][20/500]	 Consumed tensor([25, 27])
[1]<stdout>:Epoch: [0][21/500]	 Consumed tensor([29, 31])
[0]<stdout>:Epoch: [0][21/500]	 Consumed tensor([28, 30])
[0]<stdout>:Epoch: [0][22/500]	 Consumed tensor([32, 34])
[1]<stdout>:Epoch: [0][22/500]	 Consumed tensor([33, 35])
[0]<stdout>:Epoch: [0][23/500]	 Consumed tensor([36, 38])
[1]<stdout>:Epoch: [0][23/500]	 Consumed tensor([37, 39])
[0]<stdout>:train rank=0
[0]<stdout>:epoch=0
[0]<stdout>:Epoch 0 / 100, Start training
[0]<stdout>:train... rank=0
[0]<stdout>:start enumerate train_loader... rank=0
[1]<stdout>:train rank=1
[1]<stdout>:epoch=0
[1]<stdout>:Epoch 0 / 100, Start training
[1]<stdout>:train... rank=1
[1]<stdout>:start enumerate train_loader... rank=1
[0]<stdout>:Epoch: [0][24/500]	 Consumed tensor([0, 2])
[1]<stdout>:Epoch: [0][24/500]	 Consumed tensor([1, 3])
[1]<stdout>:Epoch: [0][25/500]	 Consumed tensor([5, 7])
[0]<stdout>:Epoch: [0][25/500]	 Consumed tensor([4, 6])
[1]<stdout>:Epoch: [0][26/500]	 Consumed tensor([ 9, 11])
[0]<stdout>:Epoch: [0][26/500]	 Consumed tensor([ 8, 10])
[1]<stdout>:Epoch: [0][27/500]	 Consumed tensor([13, 15])

As the log shows, every time a graceful node removal happens, the already-consumed items are served again: after each restart the workers go back to index 0 instead of continuing from where they left off.

After investigating the issue, I found the root cause: upon graceful node removal, everything continues as is and the sampler is never reset. So when __iter__ is called again (as training resumes after the node removal), the sampler's internal state has not been updated and the iteration order is computed from the old values, re-yielding samples that were already processed.
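To make the mechanism concrete, here is a minimal standalone sketch of the failure mode. This is an illustrative toy, not the Horovod ElasticSampler itself, and all names in it are made up:

class ToySampler:
    def __init__(self, n):
        self.n = n
        self.processed = set()            # indices already consumed this epoch
        self.remaining = list(range(n))   # indices still to be visited

    def record(self, indices):
        # Called by the training loop as batches are consumed.
        self.processed.update(indices)

    def reset(self):
        # Recompute the indices still to be visited, excluding processed ones.
        self.remaining = [i for i in range(self.n) if i not in self.processed]

    def __iter__(self):
        # Problem: this yields whatever `remaining` held at the last reset(),
        # ignoring everything recorded via record() since then.
        return iter(self.remaining)


sampler = ToySampler(8)
it = iter(sampler)
sampler.record([next(it), next(it)])   # indices 0 and 1 consumed, then a host change hits

# Training restarts and iteration starts over: without a reset(), the
# already-consumed 0 and 1 are handed out a second time.
print(list(iter(sampler)))             # prints [0, 1, 2, 3, 4, 5, 6, 7]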

Proposed solution

Simply adding self.reset() at the beginning of __iter__ here solves the problem in my tests.
I can open a PR for this if it is deemed a good solution.
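For reference, the change would look roughly like this; everything except the added reset() call is elided, since I am not quoting the actual body of __iter__:

# In ElasticSampler (sketch; the rest of __iter__ stays as it is today):
def __iter__(self):
    self.reset()   # recompute the remaining indices from the ones already
                   # processed, before handing out a new iteration order
    # ... existing __iter__ body, unchanged ...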
