
[Elastic Horovod] It will lose some indices of processed samples in hvd.elastic.state when some nodes are dropped #3143

Closed
hgx1991 opened this issue Sep 1, 2021 · 0 comments · Fixed by #3144

hgx1991 commented Sep 1, 2021

Environment:

  1. Framework: PyTorch
  2. Framework version: 1.7.0+cu101
  3. Horovod version: 0.22.1

Checklist:

  1. Did you search issues to find if somebody asked this question before?
    yes
  2. If your question is about hang, did you read this doc?
    yes
  3. If your question is about docker, did you read this doc?
    yes
  4. Did you check if your question is answered in the troubleshooting guide?
    yes

Bug report:

The relevant code is in horovod/torch/elastic/state.py:

class SamplerStateHandler(StateHandler):
    def __init__(self, sampler):
        super().__init__(sampler)
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())

    def save(self):
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())

    def restore(self):
        self.value.load_state_dict(self._saved_sampler_state)

    def sync(self):
        # Get the set of processed indices from all workers
        world_processed_indices = _union(allgather_object(self.value.processed_indices))

        # Replace local processed indices with global indices
        state_dict = self.value.state_dict()
        state_dict['processed_indices'] = world_processed_indices

        # Broadcast and load the state to make sure we're all in sync
        self.value.load_state_dict(broadcast_object(state_dict))

When state.commit() is called, save() above only saves the local state of the ElasticSampler on each worker. If a node is dropped for some reason, the indices of the samples processed on that node are lost. After the restart and sync, those samples will be processed again, which is not what we want.
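To make the failure mode concrete, here is a small standalone sketch in plain Python (not the Horovod API; the worker ranks and index sets are invented for illustration):

# Suppose 3 workers have each committed disjoint sets of processed indices,
# and then the node running worker 2 drops before the restart.
saved_local_states = {
    0: {0, 3, 6},   # worker 0's locally saved processed_indices
    1: {1, 4, 7},   # worker 1
    2: {2, 5, 8},   # worker 2 -- this node is lost
}

# After the restart, sync() can only allgather from the surviving workers,
# so the union no longer contains anything worker 2 had processed.
survivors = [0, 1]
world_processed_indices = set().union(*(saved_local_states[r] for r in survivors))
print(sorted(world_processed_indices))  # [0, 1, 3, 4, 6, 7] -- indices 2, 5, 8 will be re-processed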

Steps to reproduce

  1. Add some logging to SamplerStateHandler.sync():
class SamplerStateHandler(StateHandler):
    ......
    def sync(self):
        # Get the set of processed indices from all workers
        world_processed_indices = _union(allgather_object(self.value.processed_indices))
        print(f"world_processed_indices: {world_processed_indices }")
        ......
  2. Use the code below to reproduce. Note that shuffling is disabled so the processed indices are easy to follow.
#! /usr/bin/env python3
# -*- coding: utf-8 -*-


import time
import torch
import horovod.torch as hvd


BATCH_SIZE_PER_GPU = 2


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, n):
        self.n = n
    
    def __getitem__(self, index):
        index = index % self.n
        return index

    def __len__(self):
        return self.n


@hvd.elastic.run
def train(state, data_loader, a):
    rank = hvd.rank()
    print(f"train rank={rank}")
    total_epoch = 100
    for epoch in range(state.epoch, total_epoch):
        print(f"epoch={epoch}")

        print("Epoch {} / {}, Start training".format(epoch, total_epoch))
        print(f"train... rank={rank}")
        print(f"start enumerate train_loader... rank={rank}")
        batch_offset = state.batch
        for i, d in enumerate(data_loader):
            state.batch = batch_idx = batch_offset + i
            # Commit the elastic state every 5 batches; save() only stores the
            # sampler state locally, which is where the reported bug shows up.
            if state.batch % 5 == 0:
                t1 = time.time()
                state.commit()
                print(f"time: {time.time() - t1}")
            state.check_host_updates()
            b = hvd.allreduce(a)
            print(f"b: {b}")
            # Record the indices of this batch as processed in the ElasticSampler.
            state.train_sampler.record_batch(i, BATCH_SIZE_PER_GPU)
    
            # if rank == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t'.format(
                state.epoch, state.batch, len(data_loader))

            print(msg)
            time.sleep(0.5)
        state.epoch += 1
        state.batch = 0
        data_loader.sampler.set_epoch(epoch)
        state.commit()


def main():
    hvd.init()
    torch.manual_seed(219)
    torch.cuda.set_device(hvd.local_rank())

    dataset = MyDataset(2000)
    sampler = hvd.elastic.ElasticSampler(dataset, shuffle=False)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=2,
        sampler=sampler,
        worker_init_fn=None,
        drop_last=True,
    )
    a = torch.Tensor([1,2,3,4])
    state = hvd.elastic.TorchState(epoch=0,
                                   train_sampler=sampler,
                                   batch=0)

    train(state, data_loader, a)


if __name__ == "__main__":
    main()
  3. Run the above code with Elastic Horovod on several nodes, for example on 3 nodes (an example launch command is shown after this list).
  4. After a while, kill the processes on one of the nodes.
  5. Observe the log line we added to SamplerStateHandler.sync(): the union of processed indices no longer contains the indices that were processed on the killed node.
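For reference, a launch along these lines can be used; the script name train_elastic.py, the host discovery script path, and the worker counts are placeholders for this example:

horovodrun -np 3 --min-np 1 --max-np 3 --host-discovery-script ./discover_hosts.sh python train_elastic.py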

Solutions

  1. Solution 1
    Save the global state. To get the global state of the ElasticSampler on every node, save() should call sync() first:
class SamplerStateHandler(StateHandler):
    ......
    def save(self):
        self.sync()
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())
    ......

But this makes state.commit() expensive: every commit would then perform an allgather and a broadcast across all workers, so commits take a long time.
2. Solution 2
Maybe we can save the number of processed samples instead of saving all the processed indices. The number of processed samples can be calculated locally from batch_size and num_replicas, so commit() would not need any collective communication. A rough sketch is shown below.
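The following is only a rough sketch of solution 2, not the actual Horovod implementation; the class name CountingSamplerState and the attribute processed_num are made up for illustration:

# Hypothetical sketch of solution 2 (not Horovod code): track only how many
# samples the whole job has consumed this epoch, derived locally from the
# batch counter, the batch size, and num_replicas.
class CountingSamplerState:
    def __init__(self, num_replicas):
        self.num_replicas = num_replicas
        self.processed_num = 0  # samples processed by all workers this epoch

    def record_batch(self, batch_idx, batch_size):
        # All workers advance in lockstep, so after batch_idx + 1 batches the
        # job has consumed (batch_idx + 1) * batch_size * num_replicas samples.
        self.processed_num = (batch_idx + 1) * batch_size * self.num_replicas

    def remaining_indices(self, indices):
        # On restart, simply skip the samples that were already processed.
        return indices[self.processed_num:]

# Example: 3 workers, batch size 2, 5 batches processed before the failure.
state = CountingSamplerState(num_replicas=3)
state.record_batch(batch_idx=4, batch_size=2)
print(state.remaining_indices(list(range(2000)))[:3])  # [30, 31, 32]

Since the count is the same on every surviving worker, nothing is lost when a node drops, and commit() stays cheap because no allgather or broadcast is needed.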

@hgx1991 hgx1991 added the bug label Sep 1, 2021
@hgx1991 hgx1991 changed the title from "[Elastic Horovod] It will loss some indices of trained samples in hvd.elastic.state when some nodes dropped" to "[Elastic Horovod] It will loss some indices of processed samples in hvd.elastic.state when some nodes dropped" Sep 1, 2021