Implement len in IterableDatasetShard (huggingface#13780)
sgugger authored and Alberto Bégué committed Jan 27, 2022
1 parent e27fbf2 commit b6207de
Showing 1 changed file with 2 additions and 4 deletions.
src/transformers/trainer_pt_utils.py: 2 additions & 4 deletions
@@ -152,8 +152,6 @@ def nested_xla_mesh_reduce(tensors, name):
 
         if isinstance(tensors, (list, tuple)):
             return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
-        if tensors.ndim == 0:
-            tensors = tensors[None]
         return xm.mesh_reduce(name, tensors, torch.cat)
     else:
         raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
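
Background on the `tensors[None]` line visible in the hunk above: torch.cat cannot concatenate zero-dimensional tensors, so a scalar (for example a loss value reduced on TPU) has to be given a leading dimension first. A minimal standalone illustration, not part of the commit:

import torch

scalar = torch.tensor(3.14)  # 0-dim tensor, e.g. a loss value
# torch.cat([scalar, scalar])  # would raise RuntimeError: zero-dimensional tensors cannot be concatenated
stacked = torch.cat([scalar[None], scalar[None]])  # indexing with None adds a leading dim
print(stacked.shape)  # torch.Size([2])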
@@ -777,9 +775,9 @@ def __iter__(self):
     def __len__(self):
         # Will raise an error if the underlying dataset is not sized.
         if self.drop_last:
-            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
+            return len(self.dataset) // self.num_processes
         else:
-            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
+            return math.ceil(len(self.dataset) / self.num_processes)
 
 
 # In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
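For orientation, below is a small self-contained sketch of the two per-shard length computations that appear in the __len__ hunk above: a plain split of the dataset across processes, and a batch-aligned split that counts whole chunks of batch_size * num_processes samples. The helper names and the concrete numbers are illustrative only and not part of the commit:

import math

def per_process_length(dataset_len, num_processes, drop_last):
    # Plain split: each process reports its own share of the sized dataset.
    if drop_last:
        return dataset_len // num_processes
    return math.ceil(dataset_len / num_processes)

def batch_aligned_length(dataset_len, batch_size, num_processes, drop_last):
    # Batch-aligned split: count whole chunks of batch_size * num_processes
    # samples, then multiply by the batch_size each process takes per chunk.
    chunk = batch_size * num_processes
    if drop_last:
        return (dataset_len // chunk) * batch_size
    return math.ceil(dataset_len / chunk) * batch_size

# Example: 100 samples, batch_size=8, 3 processes.
assert per_process_length(100, 3, drop_last=True) == 33
assert per_process_length(100, 3, drop_last=False) == 34
assert batch_aligned_length(100, 8, 3, drop_last=True) == 32   # 4 full chunks of 24 samples, 8 per process each
assert batch_aligned_length(100, 8, 3, drop_last=False) == 40  # 5 chunks when rounding up, 8 per process each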
