Reproducing the problem
The reproduction code is based on this previous issue, with a small modification to print each item itself (a number representing its index) for simplicity:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import time

import torch
import horovod.torch as hvd

BATCH_SIZE_PER_GPU = 2


class MyDataset(torch.utils.data.Dataset):
    """Trivial dataset: item i is just the index i."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        index = index % self.n
        return index

    def __len__(self):
        return self.n


@hvd.elastic.run
def train(state, data_loader, a):
    rank = hvd.rank()
    print(f"train rank={rank}")
    total_epoch = 100
    for epoch in range(state.epoch, total_epoch):
        print(f"epoch={epoch}")
        print("Epoch {} / {}, Start training".format(epoch, total_epoch))
        print(f"train... rank={rank}")
        print(f"start enumerate train_loader... rank={rank}")
        batch_offset = state.batch
        for i, d in enumerate(data_loader):
            state.batch = batch_idx = batch_offset + i
            if state.batch % 50 == 0:
                # Commit state periodically (committing is expensive).
                t1 = time.time()
                state.commit()
                print(f"time: {time.time() - t1}")
            # Check for added/removed hosts on every batch (cheap).
            state.check_host_updates()
            b = hvd.allreduce(a)
            # print(f"b: {b}")
            state.train_sampler.record_batch(i, BATCH_SIZE_PER_GPU)
            msg = 'Epoch: [{0}][{1}/{2}]\t Consumed {3}'.format(
                state.epoch, state.batch, len(data_loader), d)
            print(msg)
            time.sleep(1)
        state.epoch += 1
        state.batch = 0
        data_loader.sampler.set_epoch(epoch)
        state.commit()


def main():
    hvd.init()
    torch.manual_seed(219)
    torch.cuda.set_device(hvd.local_rank())

    dataset = MyDataset(2000)
    sampler = hvd.elastic.ElasticSampler(dataset, shuffle=False)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=2,
        sampler=sampler,
        worker_init_fn=None,
        drop_last=True,
    )

    a = torch.Tensor([1, 2, 3, 4])
    state = hvd.elastic.TorchState(epoch=0,
                                   train_sampler=sampler,
                                   batch=0)
    train(state, data_loader, a)


if __name__ == "__main__":
    main()
To simulate a graceful node shutdown followed by resumption every few seconds, I use the following host-discovery script (it assumes a Ray cluster is present, but is easy to adapt to any other setup):
#!/usr/bin/env python3
import ray
from datetime import datetime

ray.init('auto', logging_level='critical')

# All alive worker nodes, excluding the Ray head node.
ips = [k['NodeManagerAddress'] for k in ray.nodes()
       if k['Alive'] and 'head' not in k['NodeManagerHostname']]

# During every other 10-second window, drop the first host to
# simulate a graceful removal followed by its resumption.
ips = ips[datetime.now().second // 10 % 2:]

print('\n'.join(ips))
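For reference, a launch command along these lines runs the repro with Horovod's elastic runner (exact flags vary across Horovod versions; discover_hosts.py and repro.py are placeholder names for the two scripts above, and the discovery script must be executable):

horovodrun -np 2 --min-np 1 --max-np 2 --host-discovery-script ./discover_hosts.py python repro.py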
It can easily be seen that every time a graceful node removal happens, items are repeated again and again.
After investigating the issue, I found that the root cause is that upon graceful node removal everything continues as is and the sampler is not reset: when __iter__ is called to resume training after the node removal, nothing has been updated, so the indices are still computed from the old values.
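To make the failure mode concrete, here is a toy sketch (my own illustration, not Horovod code) of a distributed sampler partitioning with stale values:

# Toy illustration: each rank takes every num_replicas-th index,
# the usual distributed-sampler partition.
def partition(indices, rank, num_replicas):
    return indices[rank::num_replicas]

indices = list(range(8))
# With 2 workers the data is split cleanly:
print(partition(indices, 0, 2), partition(indices, 1, 2))
# -> [0, 2, 4, 6] [1, 3, 5, 7]

# After one worker leaves gracefully, the survivor should take over
# everything, but without a reset it keeps its old view
# (rank=0, num_replicas=2) and re-yields the same stale slice:
print(partition(indices, 0, 2))
# -> [0, 2, 4, 6] again: these items repeat, the rest are starved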
Proposed solution
Simply adding self.reset() at the beginning of __iter__ here solves the problem in my tests.
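For clarity, a minimal sketch of the change (simplified; the actual ElasticSampler keeps more state, and reset() recomputes the per-rank index partition):

# Simplified sketch of ElasticSampler.__iter__ with the proposed fix;
# reset() re-reads the current world size and processed indices, so
# the partition is recomputed after an elastic membership change.
def __iter__(self):
    self.reset()  # proposed one-line fix
    return iter(self.indices)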
I can open a PR for this if it is deemed a good solution