-
-
Notifications
You must be signed in to change notification settings - Fork 385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
extra sampler + docs #1262
extra sampler + docs #1262
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
from collections import Counter | ||
import logging | ||
from operator import itemgetter | ||
from random import choices, sample | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
|
@@ -20,6 +20,46 @@ class BalanceClassSampler(Sampler): | |
labels: list of class label for each elem in the dataset | ||
mode: Strategy to balance classes. | ||
Must be one of [downsampling, upsampling] | ||
|
||
Python API examples: | ||
|
||
.. code-block:: python | ||
|
||
import os | ||
from torch import nn, optim | ||
from torch.utils.data import DataLoader | ||
from catalyst import dl | ||
from catalyst.data import ToTensor, BalanceClassSampler | ||
from catalyst.contrib.datasets import MNIST | ||
|
||
train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) | ||
train_labels = train_data.targets.cpu().numpy().tolist() | ||
train_sampler = BalanceClassSampler(train_labels, mode=5000) | ||
valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()) | ||
|
||
loaders = { | ||
"train": DataLoader(train_data, sampler=train_sampler, batch_size=32), | ||
"valid": DataLoader(valid_data, batch_size=32), | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [pep8] reported by reviewdog 🐶 |
||
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) | ||
criterion = nn.CrossEntropyLoss() | ||
optimizer = optim.Adam(model.parameters(), lr=0.02) | ||
|
||
runner = dl.SupervisedRunner() | ||
# model training | ||
runner.train( | ||
model=model, | ||
criterion=criterion, | ||
optimizer=optimizer, | ||
loaders=loaders, | ||
num_epochs=1, | ||
logdir="./logs", | ||
valid_loader="valid", | ||
valid_metric="loss", | ||
minimize_valid_metric=True, | ||
verbose=True, | ||
) | ||
""" | ||
|
||
def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"): | ||
|
@@ -165,24 +205,165 @@ def __iter__(self) -> Iterator[int]: | |
""" | ||
inds = [] | ||
|
||
for cls_id in sample(self._classes, self._num_epoch_classes): | ||
for cls_id in random.sample(self._classes, self._num_epoch_classes): | ||
all_cls_inds = find_value_ids(self._labels, cls_id) | ||
|
||
# we've checked in __init__ that this value must be > 1 | ||
num_samples_exists = len(all_cls_inds) | ||
|
||
if num_samples_exists < self._k: | ||
selected_inds = sample(all_cls_inds, k=num_samples_exists) + choices( | ||
selected_inds = random.sample(all_cls_inds, k=num_samples_exists) + random.choices( | ||
all_cls_inds, k=self._k - num_samples_exists | ||
) | ||
else: | ||
selected_inds = sample(all_cls_inds, k=self._k) | ||
selected_inds = random.sample(all_cls_inds, k=self._k) | ||
|
||
inds.extend(selected_inds) | ||
|
||
return iter(inds) | ||
|
||
|
||
class BatchBalanceClassSampler(Sampler):
    """
    BatchSampler version of BalanceBatchSampler.
    This kind of sampler can be used for both metric learning and classification task.

    BatchSampler with the given strategy for the C unique classes dataset:
    - Selection `num_classes` of C classes for each batch
    - Selection `num_samples` instances for each class in the batch
    The epoch ends after `num_batches`.
    So, the batch size is `num_classes` * `num_samples`.

    One of the purposes of this sampler is to be used for
    forming triplets and pos/neg pairs inside the batch.
    To guarantee existence of these pairs in the batch,
    `num_classes` and `num_samples` should be > 1. (1)

    This type of sampling can be found in the classical paper of Person Re-Id,
    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
    `In Defense of the Triplet Loss for Person Re-Identification`_.

    Args:
        labels: list of class labels for each elem in the dataset
        num_classes: number of classes in a batch, should be > 1
        num_samples: number of instances of each class in a batch, should be > 1
        num_batches: number of batches in epoch
            (default = len(labels) // (num_classes * num_samples))

    .. _In Defense of the Triplet Loss for Person Re-Identification:
        https://arxiv.org/abs/1703.07737

    Python API examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl
        from catalyst.data import ToTensor, BatchBalanceClassSampler
        from catalyst.contrib.datasets import MNIST

        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4)
        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

        loaders = {
            "train": DataLoader(train_data, batch_sampler=train_sampler),
            "valid": DataLoader(valid_data, batch_size=32),
        }

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        runner = dl.SupervisedRunner()
        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir="./logs",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=True,
        )
    """

    def __init__(
        self,
        labels: Union[List[int], np.ndarray],
        num_classes: int,
        num_samples: int,
        num_batches: int = None,
    ):
        """Sampler initialisation."""
        super().__init__(labels)
        classes = set(labels)

        assert isinstance(num_classes, int) and isinstance(num_samples, int)
        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
        assert all(
            n > 1 for n in Counter(labels).values()
        ), "Each class should contain at least 2 instances to fit (1)"

        labels = np.array(labels)
        # unique class labels present in the dataset
        self._labels = list(set(labels.tolist()))
        self._num_classes = num_classes
        self._num_samples = num_samples
        self._batch_size = self._num_classes * self._num_samples
        self._num_batches = num_batches or len(labels) // self._batch_size
        # mapping: class label -> list of dataset indices carrying that label
        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
        }

    @property
    def batch_size(self) -> int:
        """
        Returns:
            this value should be used in DataLoader as batch size
        """
        return self._batch_size

    @property
    def batches_in_epoch(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches

    def __len__(self) -> int:
        """
        Returns:
            number of batches in an epoch (each yielded element is one
            batch of indices, so the length is counted in batches)
        """
        return self._num_batches

    def __iter__(self) -> Iterator[List[int]]:
        """
        Returns:
            iterator over per-batch lists of dataset indices for an epoch
        """
        indices = []
        for _ in range(self._num_batches):
            batch_indices = []
            # random.sample draws without replacement from the unique labels,
            # so the selected classes are guaranteed to be distinct already —
            # no re-draw loop is needed
            classes_for_batch = random.sample(self._labels, self._num_classes)
            for cls_id in classes_for_batch:
                # sample with replacement only when the class has fewer
                # instances than requested
                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
                batch_indices += np.random.choice(
                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
                ).tolist()
            indices.append(batch_indices)
        return iter(indices)
|
||
|
||
class DynamicBalanceClassSampler(Sampler): | ||
""" | ||
This kind of sampler can be used for classification tasks with significant | ||
|
@@ -552,6 +733,7 @@ def __iter__(self) -> Iterator[int]: | |
__all__ = [ | ||
"BalanceClassSampler", | ||
"BalanceBatchSampler", | ||
"BatchBalanceClassSampler", | ||
"DistributedSamplerWrapper", | ||
"DynamicBalanceClassSampler", | ||
"DynamicLenBatchSampler", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[pep8] reported by reviewdog 🐶
W293 blank line contains whitespace