`horovod/torch/elastic/state.py`:

```python
import copy

# Excerpt: StateHandler, _union, allgather_object and broadcast_object
# come from Horovod and are not redefined here.
class SamplerStateHandler(StateHandler):
    def __init__(self, sampler):
        super().__init__(sampler)
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())

    def save(self):
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())

    def restore(self):
        self.value.load_state_dict(self._saved_sampler_state)

    def sync(self):
        # Get the set of processed indices from all workers
        world_processed_indices = _union(allgather_object(self.value.processed_indices))

        # Replace local processed indices with global indices
        state_dict = self.value.state_dict()
        state_dict['processed_indices'] = world_processed_indices

        # Broadcast and load the state to make sure we're all in sync
        self.value.load_state_dict(broadcast_object(state_dict))
```
When `state.commit()` is called, `save()` above only saves the state of `ElasticSampler` locally. If a node drops out for any reason, the indices of the samples processed on that node are lost. So after the restart and `sync()`, those samples will be processed again, which is not what we want.
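The loss can be illustrated without Horovod. The following is a minimal pure-Python simulation under stated assumptions: `ToyWorker`, `sync` and `simulate` are hypothetical names, not Horovod's API; each rank owns a strided, unshuffled shard; and it only models indices processed since the last `sync()` (indices merged at an earlier sync would already be known to every rank).

```python
from typing import List, Set, Tuple


class ToyWorker:
    """Hypothetical stand-in for one rank's sampler state (not Horovod's API)."""

    def __init__(self) -> None:
        self.processed_indices: Set[int] = set()  # live state, lost if the node dies
        self.saved_indices: Set[int] = set()      # what the last commit() copied

    def commit(self) -> None:
        # Like SamplerStateHandler.save(): copy only this rank's local state.
        self.saved_indices = set(self.processed_indices)


def sync(survivors: List[ToyWorker]) -> Set[int]:
    # Like SamplerStateHandler.sync(): union of what the surviving ranks saved.
    return set().union(*(w.saved_indices for w in survivors))


def simulate(num_samples: int = 12, num_replicas: int = 3,
             batch_size: int = 2) -> Tuple[Set[int], Set[int]]:
    workers = [ToyWorker() for _ in range(num_replicas)]
    for rank, worker in enumerate(workers):
        # No shuffling: rank r owns indices r, r + num_replicas, r + 2 * num_replicas, ...
        shard = list(range(rank, num_samples, num_replicas))
        worker.processed_indices.update(shard[:batch_size])  # train one batch
        worker.commit()                                       # state.commit() on every rank
    all_processed = set().union(*(w.processed_indices for w in workers))
    # The last rank's node drops out; only the other ranks restart and sync.
    recovered = sync(workers[:-1])
    return recovered, all_processed - recovered


recovered, lost = simulate()
print(sorted(recovered), sorted(lost))  # [0, 1, 3, 4] [2, 5]
```

Indices 2 and 5 were trained on the dropped rank, but no surviving rank saved them, so they go back into the pool of remaining samples.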
Steps to reproduce
Add some logging to `SamplerStateHandler.sync()`:
```python
class SamplerStateHandler(StateHandler):
    ...

    def sync(self):
        # Get the set of processed indices from all workers
        world_processed_indices = _union(allgather_object(self.value.processed_indices))
        print(f"world_processed_indices: {world_processed_indices}")
        ...
```
Run the reproduction code with shuffling disabled, which makes the printed indices easy to follow.

Solutions

1. Solution 1

Save the global state. To get the global state of `ElasticSampler` on every node, `save()` should call `sync()` first. But this will cause `state.commit()` to take a long time, because every commit would then need an allgather and a broadcast of the processed indices.

2. Solution 2

Maybe we can save the number of processed samples instead of all the processed indices. The number of processed samples can be calculated locally from `batch_size` and `num_replicas`.

hgx1991 changed the title from "[Elastic Horovod] It will loss some indices of trained samples in hvd.elastic.state when some nodes dropped" to "[Elastic Horovod] It will loss some indices of processed samples in hvd.elastic.state when some nodes dropped" on Sep 1, 2021.
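To make solution 2 concrete, here is a minimal sketch of why a single counter can be enough. It assumes synchronous training (every rank finishes the same number of batches), a strided partition of the index list across ranks, and a deterministic index order (no shuffle, or a shuffle seeded the same way on every rank). The function names are hypothetical and are not fields of `ElasticSampler`'s `state_dict`.

```python
from typing import List, Set


def processed_count(batches_done: int, batch_size: int, num_replicas: int) -> int:
    # Assumes every rank finished the same number of batches, so the global
    # count can be computed locally without any collective.
    return batches_done * batch_size * num_replicas


def shard(indices: List[int], rank: int, num_replicas: int) -> List[int]:
    # Strided partition of a deterministic index list (assumption, see above).
    return indices[rank::num_replicas]


def remaining_after_restart(indices: List[int], batches_done: int,
                            batch_size: int, num_replicas: int) -> List[int]:
    # With a strided partition, the first `count` entries of `indices` are exactly
    # the samples processed across all ranks, so one integer restores the state.
    count = processed_count(batches_done, batch_size, num_replicas)
    return indices[count:]


def processed_by_ranks(indices: List[int], batches_done: int,
                       batch_size: int, num_replicas: int) -> Set[int]:
    # Brute-force check: collect what each rank actually consumed from its shard.
    done: Set[int] = set()
    for rank in range(num_replicas):
        done.update(shard(indices, rank, num_replicas)[: batches_done * batch_size])
    return done


indices = list(range(12))
print(remaining_after_restart(indices, batches_done=1, batch_size=2, num_replicas=3))
# [6, 7, 8, 9, 10, 11]
assert processed_by_ranks(indices, 1, 2, 3) == set(indices[:6])
```

The count is only meaningful relative to the index list in effect since the last reset; if `num_replicas` changes, the counter would have to be folded into the remaining-indices list at that point, as `sync()` does today with the index sets.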