diff --git a/.github/workflows/test_requirements.txt b/.github/workflows/test_requirements.txt index fb832230..77be14e0 100644 --- a/.github/workflows/test_requirements.txt +++ b/.github/workflows/test_requirements.txt @@ -9,3 +9,4 @@ py-cpuinfo>=7.0.0 pytest==5.3.4 pytest-cov==2.8.1 coverage==4.5.2 +PyYAML>=5.3.1 diff --git a/environment.yml b/environment.yml index 6946bbc7..98d0e7b0 100644 --- a/environment.yml +++ b/environment.yml @@ -24,3 +24,4 @@ dependencies: - setuptools>=41.0.0 - tqdm==4.49.0 - twine + - PyYAML==5.3.1 diff --git a/setup.py b/setup.py index 13b3fbb7..35634142 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def find_version(*file_paths): exclude=["build", "scripts", "dist", "images", "test_fixtures", "tests"] ), install_requires=["torch>=1.2.0"], - tests_require=["pytest", "pytest-cov"], + tests_require=["pytest", "pytest-cov", "PyYAML"], python_requires=">=3.6", classifiers=[ "Programming Language :: Python :: 3", diff --git a/test_fixtures/config.yml b/test_fixtures/config.yml new file mode 100644 index 00000000..94ad7b8c --- /dev/null +++ b/test_fixtures/config.yml @@ -0,0 +1,4 @@ +transform: Gain +params: + min_gain_in_db: -12.0 + mode: per_channel diff --git a/test_fixtures/config_compose.yml b/test_fixtures/config_compose.yml new file mode 100644 index 00000000..1af641d6 --- /dev/null +++ b/test_fixtures/config_compose.yml @@ -0,0 +1,9 @@ +transform: Compose +params: + transforms: + - transform: Gain + params: + min_gain_in_db: -12.0 + mode: per_channel + - transform: PolarityInversion + shuffle: True diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..43480e32 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,50 @@ +import unittest + +from tests.utils import TEST_FIXTURES_DIR +from torch_audiomentations import from_dict, from_yaml +from torch_audiomentations import Gain, Compose + + +class TestFromConfig(unittest.TestCase): + def test_from_dict(self): + config = { + "transform": 
"Gain", + "params": {"min_gain_in_db": -12.0, "mode": "per_channel"}, + } + transform = from_dict(config) + + assert isinstance(transform, Gain) + assert transform.min_gain_in_db == -12.0 + assert transform.max_gain_in_db == 6.0 + assert transform.mode == "per_channel" + + def test_from_yaml(self): + file_yml = TEST_FIXTURES_DIR / "config.yml" + transform = from_yaml(file_yml) + + assert isinstance(transform, Gain) + assert transform.min_gain_in_db == -12.0 + assert transform.max_gain_in_db == 6.0 + assert transform.mode == "per_channel" + + def test_from_dict_compose(self): + config = { + "transform": "Compose", + "params": { + "shuffle": True, + "transforms": [ + { + "transform": "Gain", + "params": {"min_gain_in_db": -12.0, "mode": "per_channel"}, + }, + {"transform": "PolarityInversion"}, + ], + }, + } + transform = from_dict(config) + assert isinstance(transform, Compose) + + def test_from_yaml_compose(self): + file_yml = TEST_FIXTURES_DIR / "config_compose.yml" + transform = from_yaml(file_yml) + assert isinstance(transform, Compose) diff --git a/torch_audiomentations/__init__.py b/torch_audiomentations/__init__.py index 9cfd3cc4..7112e3af 100644 --- a/torch_audiomentations/__init__.py +++ b/torch_audiomentations/__init__.py @@ -3,5 +3,6 @@ from .augmentations.peak_normalization import PeakNormalization from .utils.convolution import convolve +from .utils.config import from_dict, from_yaml __version__ = "0.3.0" diff --git a/torch_audiomentations/utils/config.py b/torch_audiomentations/utils/config.py new file mode 100644 index 00000000..cb6a9328 --- /dev/null +++ b/torch_audiomentations/utils/config.py @@ -0,0 +1,120 @@ +import warnings +from typing import Any, Dict, Text, Union +from pathlib import Path +import torch_audiomentations + +from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torch_audiomentations import Compose + +# TODO: define this elsewhere? +# TODO: update when a new type of transform is added (e.g. 
# BaseSpectrogramTransform? OneOf? SomeOf?)
# https://github.com/asteroid-team/torch-audiomentations/issues/26
Transform = Union[BaseWaveformTransform, Compose]


def from_dict(config: Dict[Text, Union[Text, Dict[Text, Any]]]) -> Transform:
    """Instantiate a transform from a configuration dictionary.

    `from_dict` can be used to instantiate a transform from its class name.
    For instance, these two pieces of code are equivalent:

    >>> from torch_audiomentations import Gain
    >>> transform = Gain(min_gain_in_db=-12.0)

    >>> transform = from_dict({'transform': 'Gain',
    ...                        'params': {'min_gain_in_db': -12.0}})

    Transforms composition is also supported:

    >>> compose = from_dict(
    ...    {'transform': 'Compose',
    ...     'params': {'transforms': [{'transform': 'Gain',
    ...                                'params': {'min_gain_in_db': -12.0,
    ...                                           'mode': 'per_channel'}},
    ...                               {'transform': 'PolarityInversion'}],
    ...                'shuffle': True}})

    The configuration passed in is never modified, so the same dictionary can
    be reused to instantiate several independent transforms.

    :param config: configuration dictionary with a mandatory 'transform' key
        (name of a class exposed by `torch_audiomentations`) and an optional
        'params' key (keyword arguments forwarded to that class).
    :returns: A transform.
    :rtype: Transform
    :raises ValueError: if the 'transform' key is missing, if the named
        transform is not an attribute of `torch_audiomentations`, or if
        'params' is not a dictionary.
    :raises KeyError: if a 'Compose', 'OneOf' or 'SomeOf' configuration has
        no 'transforms' entry in its 'params'.
    """

    try:
        transform_class_name: Text = config["transform"]
    except KeyError as err:
        raise ValueError(
            "A (currently missing) 'transform' key should be used to define the transform type."
        ) from err

    try:
        transform_class = getattr(torch_audiomentations, transform_class_name)
    except AttributeError as err:
        raise ValueError(
            f"torch_audiomentations does not implement {transform_class_name} transform."
        ) from err

    transform_params = config.get("params", {})
    if not isinstance(transform_params, dict):
        raise ValueError(
            "Transform parameters must be provided as {'param_name': param_value} dictionary."
        )

    # Work on a shallow copy: the recursive step below replaces the
    # "transforms" entry, and doing that on the caller's dictionary would
    # corrupt `config` (a second `from_dict(config)` call would then receive
    # already-instantiated transforms instead of their configurations).
    transform_params = dict(transform_params)

    if transform_class_name in ("Compose", "OneOf", "SomeOf"):
        # Composite transforms take a list of sub-transform configurations
        # that must themselves be instantiated first.
        transform_params["transforms"] = [
            from_dict(sub_transform_config)
            for sub_transform_config in transform_params["transforms"]
        ]

    return transform_class(**transform_params)


def from_yaml(file_yml: Union[Path, Text]) -> Transform:
    """Instantiate a transform from a YAML configuration file.

    `from_yaml` can be used to instantiate a transform from a YAML file.
    For instance, these two pieces of code are equivalent:

    >>> from torch_audiomentations import Gain
    >>> transform = Gain(min_gain_in_db=-12.0, mode="per_channel")

    >>> transform = from_yaml("config.yml")

    where the content of `config.yml` is something like:
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # config.yml
    transform: Gain
    params:
      min_gain_in_db: -12.0
      mode: per_channel
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    Transforms composition is also supported:
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # config.yml
    transform: Compose
    params:
      shuffle: True
      transforms:
        - transform: Gain
          params:
            min_gain_in_db: -12.0
            mode: per_channel
        - transform: PolarityInversion
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    :param file_yml: path to a YAML configuration file structured as shown
        in the examples above.
    :returns: A transform.
    :rtype: Transform
    :raises ImportError: if the PyYAML package is not installed.
    """

    # PyYAML is only a test-time requirement of the package, so it is
    # imported lazily and its absence is reported with an actionable message.
    try:
        import yaml
    except ImportError as err:
        raise ImportError(
            "PyYAML package is needed by `from_yaml`: please install it first."
        ) from err

    # Read as UTF-8 explicitly rather than relying on the platform's default
    # locale encoding. SafeLoader only builds plain Python objects.
    with open(file_yml, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    return from_dict(config)