add DynamicBalanceClassSampler #954

Merged
merged 15 commits into from Nov 9, 2020
1 change: 1 addition & 0 deletions catalyst/data/__init__.py
@@ -22,6 +22,7 @@
BalanceBatchSampler,
DistributedSamplerWrapper,
DynamicLenBatchSampler,
DynamicBalanceClassSampler,
MiniEpochSampler,
)
from catalyst.data.sampler_inbatch import (
150 changes: 150 additions & 0 deletions catalyst/data/sampler.py
@@ -1,5 +1,6 @@
from typing import Iterator, List, Optional, Union
from collections import Counter
import logging
from operator import itemgetter
from random import choices, sample

@@ -192,6 +193,154 @@ def __iter__(self) -> Iterator[int]:
return iter(inds)


class DynamicBalanceClassSampler(Sampler):
Member

hi, @Dokholyan
Could you please provide a small example for this DynamicBalanceClassSampler usage?
for example, like here - https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py#L306L325

Member

btw, could you also please add this class to the docs? here - https://github.com/catalyst-team/catalyst/blob/master/docs/api/data.rst#samplers
but please keep the alphabetical order ;)

"""
This kind of sampler can be used for classification tasks with significant
class imbalance.

The idea of this sampler is that we start with the original class distribution
and gradually move towards a uniform class distribution, as with downsampling.

Let's define :math:`D_i = \#C_i / \#C_{min}`, where :math:`\#C_i` is the size
of class :math:`i` and :math:`\#C_{min}` is the size of the rarest class, so
:math:`D_i` defines the class distribution. Also define :math:`g(n_{epoch})`,
an exponential scheduler. On each epoch the current :math:`D_i` is calculated
as :math:`D_i^{current} = D_i^{g(n_{epoch})}`; after that, data is sampled
according to this distribution.
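
For example, with classes of sizes 100, 20 and 10 and exp_lambda = 0.9, the
per-class sample counts start at (100, 20, 10), shrink to roughly (38, 15, 10)
after five epochs, and approach (10, 10, 10) as training continues.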

Notes:
At the end of training, epochs will contain only
min_size_class * n_classes examples, so it may not be necessary to run
validation on every epoch. For this reason, use ControlFlowCallback.
Contributor

can we also add import path for this callback?


Examples:

>>> import torch
>>> import numpy as np

>>> from catalyst.data import DynamicBalanceClassSampler
>>> from torch.utils import data

>>> features = torch.Tensor(np.random.random((200, 100)))
>>> labels = np.random.randint(0, 4, size=(200,))
>>> sampler = DynamicBalanceClassSampler(labels)
>>> labels = torch.LongTensor(labels)
>>> dataset = data.TensorDataset(features, labels)
>>> loader = data.dataloader.DataLoader(dataset, sampler=sampler, batch_size=8)

>>> for batch in loader:
>>> b_features, b_labels = batch

Sampler was inspired by https://arxiv.org/abs/1901.06783
"""

def __init__(
self,
labels: List[Union[int, str]],
exp_lambda: float = 0.9,
start_epoch: int = 0,
max_d: Optional[int] = None,
mode: Union[str, int] = "downsampling",
ignore_warning: bool = False,
):
"""
Args:
labels: list of labels for each elem in the dataset
exp_lambda: exponent figure for schedule
start_epoch: start epoch number, can be useful for multi-stage
experiments
max_d: if not None, a heuristic limit on the ratio between the most
frequent and the rarest classes
mode: number of samples per class at the end of training. Must be
"downsampling" or a number. Before changing it, make sure that you
understand how it works
ignore_warning: ignore warning about min class size
"""
assert isinstance(start_epoch, int)
assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
super().__init__(labels)
self.exp_lambda = exp_lambda
if max_d is None:
max_d = np.inf
self.max_d = max_d
self.epoch = start_epoch
labels = np.array(labels)
samples_per_class = Counter(labels)
self.min_class_size = min(samples_per_class.values())

if self.min_class_size < 100 and not ignore_warning:
logger = logging.getLogger(__name__)
logger.warning(
f"the smallest class contains only"
f" {self.min_class_size} examples. At the end of"
f" training, epochs will contain only"
f" {self.min_class_size * len(samples_per_class)}"
f" examples"
)

self.original_d = {
key: value / self.min_class_size
for key, value in samples_per_class.items()
}
self.label2idxes = {
label: np.arange(len(labels))[labels == label].tolist()
for label in set(labels)
}

Member

Maybe it would be better to use pure Python instead of numpy + conversion to list?

Contributor Author
@Dokholyan Dokholyan Nov 9, 2020

@bagxi This code is a simple copy from BalanceClassSampler:
    self.lbl2idx = {

if isinstance(mode, int):
self.min_class_size = mode
else:
assert mode == "downsampling"

self.labels = labels
self._update()

def _update(self) -> None:
"""
Update d coefficients
Returns: None
"""
current_d = {
key: min(value ** self._exp_scheduler(), self.max_d)
for key, value in self.original_d.items()
}
samples_per_classes = {
key: int(value * self.min_class_size)
for key, value in current_d.items()
}
self.samples_per_classes = samples_per_classes
self.length = np.sum(list(samples_per_classes.values()))
self.epoch += 1

def _exp_scheduler(self) -> float:
return self.exp_lambda ** self.epoch

def __iter__(self) -> Iterator[int]:
"""
Yields:
indices of stratified sample
"""
indices = []
for key in sorted(self.label2idxes):
samples_per_class = self.samples_per_classes[key]
replace_flag = samples_per_class > len(self.label2idxes[key])
indices += np.random.choice(
self.label2idxes[key], samples_per_class, replace=replace_flag
).tolist()
assert len(indices) == self.length
np.random.shuffle(indices)
self._update()
return iter(indices)

def __len__(self) -> int:
"""
Returns:
length of result sample
"""
return self.length


class MiniEpochSampler(Sampler):
"""
Sampler iterates mini epochs from the dataset used by ``mini_epoch_len``.
@@ -424,5 +573,6 @@ def __iter__(self):
"BalanceBatchSampler",
"MiniEpochSampler",
"DistributedSamplerWrapper",
"DynamicBalanceClassSampler",
Member

could you please add it to catalyst/data/__init__.py?

"DynamicLenBatchSampler",
]
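
To make the schedule concrete, here is a small standalone sketch (illustrative
values, plain NumPy only) that mirrors the update rule implemented by
_update and _exp_scheduler above:

from collections import Counter

import numpy as np

# toy imbalanced labels: 100 / 20 / 10 examples per class
labels = np.array([0] * 100 + [1] * 20 + [2] * 10)

samples_per_class = Counter(labels.tolist())
min_class_size = min(samples_per_class.values())  # 10
original_d = {k: v / min_class_size for k, v in samples_per_class.items()}

exp_lambda = 0.9
for epoch in (0, 5, 20):
    g = exp_lambda ** epoch  # exponential scheduler g(n_epoch)
    current_d = {k: v ** g for k, v in original_d.items()}
    counts = {k: int(v * min_class_size) for k, v in current_d.items()}
    print(epoch, counts)
# epoch 0  -> {0: 100, 1: 20, 2: 10}  (original distribution)
# epoch 5  -> {0: 38, 1: 15, 2: 10}
# epoch 20 -> {0: 13, 1: 10, 2: 10}   (approaching the rarest-class size)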
68 changes: 67 additions & 1 deletion catalyst/data/tests/test_sampler.py
@@ -3,9 +3,13 @@
from operator import itemgetter
from random import randint, shuffle

import numpy as np
import pytest

from catalyst.data.sampler import BalanceBatchSampler
from catalyst.data.sampler import (
BalanceBatchSampler,
DynamicBalanceClassSampler,
)

TLabelsPK = List[Tuple[List[int], int, int]]

@@ -122,3 +126,65 @@ def test_balance_batch_sampler(
"""
for labels, p, k in input_for_balance_batch_sampler:
check_balance_batch_sampler_epoch(labels=labels, p=p, k=k)


@pytest.fixture()
def input_for_dynamic_balance_class_sampler() -> List[Tuple[list, float]]:
"""
This function generates some valid inputs for DynamicBalanceClassSampler

Returns:
inputs in the following order: (labels, exp_l)
"""
labels = [
# class imbalance
np.array([0] * 100 + [1] * 10 + [2] * 20),
# uniform class distribution
np.array([0] * 10 + [1] * 10 + [2] * 10),
# random class distribution
np.random.randint(0, 4, size=(200,)),
]
exp_lambda = np.linspace(0.1, 0.95, 11)
input_cases = np.transpose(
[np.tile(labels, len(exp_lambda)), np.repeat(exp_lambda, len(labels))]
)
return input_cases


def check_dynamic_balance_class_sampler(labels: List, exp_l: float) -> None:
"""
Check DynamicBalanceClassSampler on certain inputs

Args:
labels: list of labels
exp_l: exponent figure for schedule
"""
sampler = DynamicBalanceClassSampler(labels, exp_l)
n_labels = len(np.unique(labels))
labels_counter = Counter(labels)
min_class_key, min_class_size = labels_counter.most_common(n_labels)[-1]
current_d = {
key: value / min_class_size for key, value in Counter(labels).items()
}
for _ in range(10): # noqa: WPS122
new_counter = Counter(labels[list(sampler.__iter__())])
new_d = {
key: value / min_class_size for key, value in new_counter.items()
}
for key, value in new_d.items():
assert value <= current_d[key]
assert new_d[min_class_key] == 1
current_d = new_d


def test_dynamic_balance_class_sampler(
input_for_dynamic_balance_class_sampler, # noqa: WPS442
) -> None:
"""
Tests for DynamicBalanceClassSampler

Args:
input_for_dynamic_balance_class_sampler: list of (labels, exp_l)
"""
for labels, exp_l in input_for_dynamic_balance_class_sampler:
check_dynamic_balance_class_sampler(labels, exp_l)
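
In the same spirit as the check above, a quick sanity-check sketch (assuming a
Catalyst install that already contains this PR) showing how each epoch's length
and class composition drift towards the rarest class:

from collections import Counter

import numpy as np
from catalyst.data import DynamicBalanceClassSampler

labels = np.array([0] * 100 + [1] * 20 + [2] * 10)  # imbalanced toy labels
sampler = DynamicBalanceClassSampler(labels, exp_lambda=0.9, ignore_warning=True)

for epoch in range(5):
    indices = list(iter(sampler))  # one epoch of indices; also advances the schedule
    print(epoch, len(indices), Counter(labels[indices].tolist()))
# the epoch length starts at 130 and shrinks towards
# min_class_size * n_classes = 10 * 3 = 30 over many epochs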
7 changes: 7 additions & 0 deletions docs/api/data.rst
@@ -234,6 +234,13 @@ DistributedSamplerWrapper
:undoc-members:
:special-members: __iter__, __len__

DynamicBalanceClassSampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.data.sampler.DynamicBalanceClassSampler
:members:
:undoc-members:
:special-members: __iter__, __len__

DynamicLenBatchSampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.data.sampler.DynamicLenBatchSampler