add DynamicBalanceClassSampler #954
@@ -1,5 +1,6 @@
from typing import Iterator, List, Optional, Union
from collections import Counter
import logging
from operator import itemgetter
from random import choices, sample
@@ -192,6 +193,154 @@ def __iter__(self) -> Iterator[int]:
        return iter(inds)

class DynamicBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for classification tasks with
    significant class imbalance.

    The idea of this sampler is that we start from the original class
    distribution and gradually move to a uniform class distribution,
    as with downsampling.

    Let's define :math:`D_i = \#C_i / \#C_{min}`, where :math:`\#C_i` is
    the size of class :math:`i` and :math:`\#C_{min}` is the size of the
    rarest class, so :math:`D_i` defines the class distribution. Also let
    :math:`g(n\_epoch)` be an exponential scheduler. On each epoch the
    current :math:`D_i` is calculated as
    :math:`current\_D_i = D_i^{g(n\_epoch)}`, after which the data is
    sampled according to this distribution.
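
    For example, with two classes of sizes 100 and 10, :math:`D = 10`.
    With the default ``exp_lambda = 0.9``, after 5 epochs
    :math:`g(5) = 0.9^5 \approx 0.59`, so
    :math:`current\_D \approx 10^{0.59} \approx 3.9`, and the larger
    class contributes about :math:`3.9 \cdot 10 \approx 39` samples per
    epoch instead of 100.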

    Notes:
        At the end of training, epochs will contain only
        min_size_class * n_classes examples, so it may not be
        necessary to run validation on every epoch. For this
        reason, use ControlFlowCallback.

Review comment: can we also add the import path for this callback?
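A sketch of what that might look like (an illustration only; the import path and the ControlFlowCallback arguments below are assumptions based on the Catalyst 20.x API and may differ between versions):

# Hypothetical example: run validation-time metrics only on selected
# epochs, since late epochs shrink to min_size_class * n_classes samples.
from catalyst.dl import AccuracyCallback, ControlFlowCallback

callbacks = [
    # the wrapped callback is executed only on the listed epochs
    ControlFlowCallback(AccuracyCallback(), epochs=[5, 10, 15]),
]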

    Examples:

        >>> import torch
        >>> import numpy as np

        >>> from catalyst.data import DynamicBalanceClassSampler
        >>> from torch.utils import data

        >>> features = torch.Tensor(np.random.random((200, 100)))
        >>> labels = np.random.randint(0, 4, size=(200,))
        >>> sampler = DynamicBalanceClassSampler(labels)
        >>> labels = torch.LongTensor(labels)
        >>> dataset = data.TensorDataset(features, labels)
        >>> loader = data.DataLoader(dataset, sampler=sampler, batch_size=8)

        >>> for batch in loader:
        ...     b_features, b_labels = batch

    This sampler was inspired by https://arxiv.org/abs/1901.06783
    """

    def __init__(
        self,
        labels: List[Union[int, str]],
        exp_lambda: float = 0.9,
        start_epoch: int = 0,
        max_d: Optional[int] = None,
        mode: Union[str, int] = "downsampling",
        ignore_warning: bool = False,
    ):
        """
        Args:
            labels: list of labels for each element in the dataset
            exp_lambda: exponent figure for the schedule
            start_epoch: starting epoch number, can be useful for
                multi-stage experiments
            max_d: if not None, a limit on the ratio between the most
                frequent and the rarest classes (a heuristic)
            mode: number of samples per class at the end of training.
                Must be "downsampling" or a number. Before changing it,
                make sure that you understand how it works
            ignore_warning: ignore the warning about the min class size
        """
        assert isinstance(start_epoch, int)
        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
        super().__init__(labels)
        self.exp_lambda = exp_lambda
        if max_d is None:
            max_d = np.inf
        self.max_d = max_d
        self.epoch = start_epoch
        labels = np.array(labels)
        samples_per_class = Counter(labels)
        self.min_class_size = min(samples_per_class.values())

        if self.min_class_size < 100 and not ignore_warning:
            logger = logging.getLogger(__name__)
            logger.warning(
                f"the smallest class contains only"
                f" {self.min_class_size} examples. At the end of"
                f" training, epochs will contain only"
                f" {self.min_class_size * len(samples_per_class)}"
                f" examples"
            )

        self.original_d = {
            key: value / self.min_class_size
            for key, value in samples_per_class.items()
        }
        self.label2idxes = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

Review comment: maybe it would be better to use pure Python instead of numpy + conversion to list?
Reply: @bagxi, this code is simply copied from BalanceClassSampler (catalyst/catalyst/data/sampler.py, line 36 in dfd21c5).
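For reference, a pure-Python version of that grouping could look like the sketch below (an illustration of the suggestion, not code from this PR; `labels` is the same label list as above):

# Hypothetical pure-Python grouping of indices by label, without numpy:
from collections import defaultdict

label2idxes = defaultdict(list)
for idx, label in enumerate(labels):
    label2idxes[label].append(idx)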

        if isinstance(mode, int):
            self.min_class_size = mode
        else:
            assert mode == "downsampling"

        self.labels = labels
        self._update()

    def _update(self) -> None:
        """
        Update the d coefficients.

        Returns: None
        """
        # anneal each class ratio toward 1 and clip it by max_d
        current_d = {
            key: min(value ** self._exp_scheduler(), self.max_d)
            for key, value in self.original_d.items()
        }
        # convert the ratios into absolute per-class sample counts
        samples_per_classes = {
            key: int(value * self.min_class_size)
            for key, value in current_d.items()
        }
        self.samples_per_classes = samples_per_classes
        self.length = np.sum(list(samples_per_classes.values()))
        self.epoch += 1

    def _exp_scheduler(self) -> float:
        return self.exp_lambda ** self.epoch

    def __iter__(self) -> Iterator[int]:
        """
        Yields:
            indices of the stratified sample
        """
        indices = []
        for key in sorted(self.label2idxes):
            samples_per_class = self.samples_per_classes[key]
            # sample with replacement only if a class has fewer examples
            # than requested for this epoch
            replace_flag = samples_per_class > len(self.label2idxes[key])
            indices += np.random.choice(
                self.label2idxes[key], samples_per_class, replace=replace_flag
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)
        self._update()
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of the resulting sample
        """
        return self.length

class MiniEpochSampler(Sampler):
    """
    Sampler iterates mini epochs from the dataset used by ``mini_epoch_len``.

@@ -424,5 +573,6 @@ def __iter__(self):
    "BalanceBatchSampler",
    "MiniEpochSampler",
    "DistributedSamplerWrapper",
    "DynamicBalanceClassSampler",
    "DynamicLenBatchSampler",
]

Review comment: could you please add it to

Review comment: hi, @Dokholyan. Could you please provide a small example of this DynamicBalanceClassSampler usage? For example, like here: https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py#L306-L325
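A minimal sketch of the requested usage (assuming only that catalyst.data exports the new sampler, as this PR adds; the dataset and shapes are made up for illustration):

# Hypothetical end-to-end usage: note that the sampler must actually be
# passed to the DataLoader, and that each epoch re-draws indices with
# class sizes moved one scheduler step closer to the uniform distribution.
import numpy as np
import torch
from torch.utils import data

from catalyst.data import DynamicBalanceClassSampler

features = torch.Tensor(np.random.random((200, 100)))
labels = np.random.randint(0, 4, size=(200,))
sampler = DynamicBalanceClassSampler(labels)
dataset = data.TensorDataset(features, torch.LongTensor(labels))
loader = data.DataLoader(dataset, sampler=sampler, batch_size=8)

for epoch in range(3):
    for b_features, b_labels in loader:
        pass  # training step goes here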

Review comment: btw, could you also please add this class to the docs, here: https://github.com/catalyst-team/catalyst/blob/master/docs/api/data.rst#samplers
But please keep the alphabetical order ;)