Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add from_dict/from_yaml utility functions #31

Merged
merged 7 commits into from Nov 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test_requirements.txt
Expand Up @@ -9,3 +9,4 @@ py-cpuinfo>=7.0.0
pytest==5.3.4
pytest-cov==2.8.1
coverage==4.5.2
PyYAML>=5.3.1
1 change: 1 addition & 0 deletions environment.yml
Expand Up @@ -24,3 +24,4 @@ dependencies:
- setuptools>=41.0.0
- tqdm==4.49.0
- twine
- PyYAML==5.3.1
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -37,7 +37,7 @@ def find_version(*file_paths):
exclude=["build", "scripts", "dist", "images", "test_fixtures", "tests"]
),
install_requires=["torch>=1.2.0"],
tests_require=["pytest", "pytest-cov"],
tests_require=["pytest", "pytest-cov", "PyYAML"],
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
4 changes: 4 additions & 0 deletions test_fixtures/config.yml
@@ -0,0 +1,4 @@
transform: Gain
params:
min_gain_in_db: -12.0
mode: per_channel
9 changes: 9 additions & 0 deletions test_fixtures/config_compose.yml
@@ -0,0 +1,9 @@
transform: Compose
params:
transforms:
- transform: Gain
params:
min_gain_in_db: -12.0
mode: per_channel
- transform: PolarityInversion
shuffle: True
50 changes: 50 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,50 @@
import unittest

from tests.utils import TEST_FIXTURES_DIR
from torch_audiomentations import from_dict, from_yaml
from torch_audiomentations import Gain, Compose


class TestFromConfig(unittest.TestCase):
def test_from_dict(self):
config = {
"transform": "Gain",
"params": {"min_gain_in_db": -12.0, "mode": "per_channel"},
}
transform = from_dict(config)

assert isinstance(transform, Gain)
assert transform.min_gain_in_db == -12.0
assert transform.max_gain_in_db == 6.0
assert transform.mode == "per_channel"

def test_from_yaml(self):
file_yml = TEST_FIXTURES_DIR / "config.yml"
transform = from_yaml(file_yml)

assert isinstance(transform, Gain)
assert transform.min_gain_in_db == -12.0
assert transform.max_gain_in_db == 6.0
assert transform.mode == "per_channel"

def test_from_dict_compose(self):
config = {
"transform": "Compose",
"params": {
"shuffle": True,
"transforms": [
{
"transform": "Gain",
"params": {"min_gain_in_db": -12.0, "mode": "per_channel"},
},
{"transform": "PolarityInversion"},
],
},
}
transform = from_dict(config)
assert isinstance(transform, Compose)

def test_from_yaml_compose(self):
file_yml = TEST_FIXTURES_DIR / "config_compose.yml"
transform = from_yaml(file_yml)
assert isinstance(transform, Compose)
1 change: 1 addition & 0 deletions torch_audiomentations/__init__.py
Expand Up @@ -3,5 +3,6 @@
from .augmentations.peak_normalization import PeakNormalization

from .utils.convolution import convolve
from .utils.config import from_dict, from_yaml

__version__ = "0.3.0"
120 changes: 120 additions & 0 deletions torch_audiomentations/utils/config.py
@@ -0,0 +1,120 @@
import warnings
from typing import Any, Dict, Text, Union
from pathlib import Path
import torch_audiomentations

from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
from torch_audiomentations import Compose

# TODO: define this elsewhere?
# TODO: update when a new type of transform is added (e.g. BaseSpectrogramTransform? OneOf? SomeOf?)
# https://github.com/asteroid-team/torch-audiomentations/issues/26
Transform = Union[BaseWaveformTransform, Compose]


def from_dict(config: Dict[Text, Union[Text, Dict[Text, Any]]]) -> Transform:
"""Instantiate a transform from a configuration dictionary.

`from_dict` can be used to instantiate a transform from its class name.
For instance, these two pieces of code are equivalent:

>>> from torch_audiomentations import Gain
>>> transform = Gain(min_gain_in_db=-12.0)

>>> transform = from_dict({'transform': 'Gain',
... 'params': {'min_gain_in_db': -12.0}})

Transforms composition is also supported:

>>> compose = from_dict(
... {'transform': 'Compose',
... 'params': {'transforms': [{'transform': 'Gain',
... 'params': {'min_gain_in_db': -12.0,
... 'mode': 'per_channel'}},
... {'transform': 'PolarityInversion'}],
... 'shuffle': True}})

:param config: configuration - a configuration dictionary
:returns: A transform.
:rtype Transform:
"""

try:
TransformClassName: Text = config["transform"]
except KeyError:
raise ValueError(
"A (currently missing) 'transform' key should be used to define the transform type."
)

try:
TransformClass = getattr(torch_audiomentations, TransformClassName)
except AttributeError:
raise ValueError(
f"torch_audiomentations does not implement {TransformClassName} transform."
)

transform_params: Dict = config.get("params", dict())
if not isinstance(transform_params, dict):
raise ValueError(
"Transform parameters must be provided as {'param_name': param_value} dictionary."
)

if TransformClassName in ["Compose", "OneOf", "SomeOf"]:
transform_params["transforms"] = [
from_dict(sub_transform_config)
for sub_transform_config in transform_params["transforms"]
]

return TransformClass(**transform_params)


def from_yaml(file_yml: Union[Path, Text]) -> Transform:
"""Instantiate a transform from a YAML configuration file.

`from_yaml` can be used to instantiate a transform from a YAML file.
For instance, these two pieces of code are equivalent:

>>> from torch_audiomentations import Gain
>>> transform = Gain(min_gain_in_db=-12.0, mode="per_channel")

>>> transform = from_yaml("config.yml")

where the content of `config.yml` is something like:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# config.yml
transform: Gain
params:
min_gain_in_db: -12.0
mode: per_channel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Transforms composition is also supported:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# config.yml
transform: Compose
params:
shuffle: True
transforms:
- transform: Gain
params:
min_gain_in_db: -12.0
mode: per_channel
- transform: PolarityInversion
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:param file_yml: configuration file - a path to a YAML file with the following structure:
:returns: A transform.
:rtype Transform:
"""

try:
import yaml
except ImportError as e:
raise ImportError(
"PyYAML package is needed by `from_yaml`: please install it first."
)

with open(file_yml, "r") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)

return from_dict(config)