Skip to content

Commit

Permalink
Convert generator in Sampler back to lazy construction
Browse files Browse the repository at this point in the history
ghstack-source-id: 17692bbbb9a7b9922e17426a543b2d0f6b59e1bf
Pull Request resolved: #63646
  • Loading branch information
ejguan committed Aug 20, 2021
1 parent 913c1f8 commit 89431a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
22 changes: 22 additions & 0 deletions test/test_dataloader.py
Expand Up @@ -1495,6 +1495,28 @@ def test_sampler_reproducibility(self):
):
self.assertEqual(list(fn()), list(fn()))

for sampler in (
RandomSampler(self.dataset, num_samples=5, replacement=True),
RandomSampler(self.dataset, replacement=False),
WeightedRandomSampler(weights, num_samples=5, replacement=True),
WeightedRandomSampler(weights, num_samples=5, replacement=False),
SubsetRandomSampler(range(10)),
):
torch.manual_seed(0)
l1 = list(sampler) + list(sampler)

torch.manual_seed(0)
l2 = list(sampler) + list(sampler)
self.assertEqual(l1, l2)

its = (iter(sampler), iter(sampler))
ls = ([], [])
for idx in range(len(sampler)):
for i in range(2):
if idx == 0:
torch.manual_seed(0)
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])

def _test_sampler(self, **kwargs):
indices = range(2, 12) # using a regular iterable
Expand Down
18 changes: 11 additions & 7 deletions torch/utils/data/sampler.py
Expand Up @@ -112,15 +112,18 @@ def num_samples(self) -> int:
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
self.generator = torch.Generator()
self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator

if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
yield from torch.randperm(n, generator=self.generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()

def __len__(self) -> int:
return self.num_samples
Expand All @@ -140,7 +143,8 @@ def __init__(self, indices: Sequence[int], generator=None) -> None:
self.generator = generator

def __iter__(self) -> Iterator[int]:
return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]

def __len__(self) -> int:
return len(self.indices)
Expand Down Expand Up @@ -183,7 +187,7 @@ def __init__(self, weights: Sequence[float], num_samples: int,

def __iter__(self) -> Iterator[int]:
rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
return iter(rand_tensor.tolist())
yield from iter(rand_tensor.tolist())

def __len__(self) -> int:
return self.num_samples
Expand Down

0 comments on commit 89431a1

Please sign in to comment.