From 142404f9ce8fe4e7a72e8c59be70a012b6f707cd Mon Sep 17 00:00:00 2001 From: Polina Kazakova Date: Tue, 20 Sep 2022 15:12:51 +0200 Subject: [PATCH] decode mp3 with librosa if torchaudio >= 0.12 doesn't work as a temporary workaround (#4923) * decode mp3 with librosa if torchaudio is > 0.12 (ideally version of ffmpeg should be checked too) * decode mp3 with torchaudio>=0.12 if it works (instead of librosa) * fix incorrect marks for mp3 tests (require torchaudio, not sndfile) * add tests for latest torchaudio + separate stage in CI for it (first try) * install ffmpeg only on ubuntu * use mock to emulate torchaudio fail, add tests for librosa (not all of them) * test torchaudio_latest only on ubuntu * try/except decoding with librosa for file-like objects * more tests for latest torchaudio, should be complete set now * replace logging with warnings * fix tests: catch warnings with a pytest context manager Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- .github/workflows/ci.yml | 10 ++ src/datasets/features/audio.py | 58 ++++++++--- tests/conftest.py | 4 + tests/features/test_audio.py | 171 ++++++++++++++++++++++++++++++++- tests/utils.py | 11 ++- 5 files changed, 237 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e2f18e0b50e..2dff4e6bd0c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,3 +72,13 @@ jobs: - name: Test with pytest run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ + - name: Install dependencies to test torchaudio>=0.12 on Ubuntu + if: ${{ matrix.os == 'ubuntu-latest' }} + run: | + pip uninstall -y torchaudio torch + pip install "torchaudio>=0.12" + sudo apt-get -y install ffmpeg + - name: Test torchaudio>=0.12 on Ubuntu + if: ${{ matrix.os == 'ubuntu-latest' }} + run: | + python -m pytest -rfExX -m torchaudio_latest -n 2 --dist loadfile -sv ./tests/features/test_audio.py diff --git 
a/src/datasets/features/audio.py b/src/datasets/features/audio.py index d68944010dd..b53dd2d2663 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -1,4 +1,5 @@ import os +import warnings from dataclasses import dataclass, field from io import BytesIO from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union @@ -268,7 +269,7 @@ def _decode_non_mp3_file_like(self, file, format=None): if version.parse(sf.__libsndfile_version__) < version.parse("1.0.30"): raise RuntimeError( "Decoding .opus files requires 'libsndfile'>=1.0.30, " - + "it can be installed via conda: `conda install -c conda-forge libsndfile>=1.0.30`" + + 'it can be installed via conda: `conda install -c conda-forge "libsndfile>=1.0.30"`' ) array, sampling_rate = sf.read(file) array = array.T @@ -282,19 +283,44 @@ def _decode_non_mp3_file_like(self, file, format=None): def _decode_mp3(self, path_or_file): try: import torchaudio - import torchaudio.transforms as T except ImportError as err: - raise ImportError( - "Decoding 'mp3' audio files, requires 'torchaudio<0.12.0': pip install 'torchaudio<0.12.0'" - ) from err - if not version.parse(torchaudio.__version__) < version.parse("0.12.0"): - raise RuntimeError( - "Decoding 'mp3' audio files, requires 'torchaudio<0.12.0': pip install 'torchaudio<0.12.0'" - ) - try: - torchaudio.set_audio_backend("sox_io") - except RuntimeError as err: - raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err + raise ImportError("To support decoding 'mp3' audio files, please install 'torchaudio'.") from err + if version.parse(torchaudio.__version__) < version.parse("0.12.0"): + try: + torchaudio.set_audio_backend("sox_io") + except RuntimeError as err: + raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err + array, sampling_rate = self._decode_mp3_torchaudio(path_or_file) + else: + try: # try torchaudio anyway because sometimes it works (depending on the os 
and os packages installed) + array, sampling_rate = self._decode_mp3_torchaudio(path_or_file) + except RuntimeError: + try: + # flake8: noqa + import librosa + except ImportError as err: + raise ImportError( + "Your version of `torchaudio` (>=0.12.0) doesn't support decoding 'mp3' files on your machine. " + "To support 'mp3' decoding with `torchaudio>=0.12.0`, please install `ffmpeg>=4` system package " + 'or downgrade `torchaudio` to <0.12: `pip install "torchaudio<0.12"`. ' + "To support decoding 'mp3' audio files without `torchaudio`, please install `librosa`: " + "`pip install librosa`. Note that decoding will be extremely slow in that case." + ) from err + # try to decode with librosa for torchaudio>=0.12.0 as a workaround + warnings.warn("Decoding mp3 with `librosa` instead of `torchaudio`, decoding is slow.") + try: + array, sampling_rate = self._decode_mp3_librosa(path_or_file) + except RuntimeError as err: + raise RuntimeError( + "Decoding of 'mp3' failed, probably because of streaming mode " + "(`librosa` cannot decode 'mp3' file-like objects, only path-like)." 
+ ) from err + + return array, sampling_rate + + def _decode_mp3_torchaudio(self, path_or_file): + import torchaudio + import torchaudio.transforms as T array, sampling_rate = torchaudio.load(path_or_file, format="mp3") if self.sampling_rate and self.sampling_rate != sampling_rate: @@ -306,3 +332,9 @@ def _decode_mp3(self, path_or_file): if self.mono: array = array.mean(axis=0) return array, sampling_rate + + def _decode_mp3_librosa(self, path_or_file): + import librosa + + array, sampling_rate = librosa.load(path_or_file, mono=self.mono, sr=self.sampling_rate) + return array, sampling_rate diff --git a/tests/conftest.py b/tests/conftest.py index c1a1af48e57..beb32b7119e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,10 @@ def pytest_collection_modifyitems(config, items): item.add_marker(pytest.mark.unit) +def pytest_configure(config): + config.addinivalue_line("markers", "torchaudio_latest: mark test to run with torchaudio>=0.12") + + @pytest.fixture(autouse=True) def set_test_cache_config(tmp_path_factory, monkeypatch): # test_hf_cache_home = tmp_path_factory.mktemp("cache") # TODO: why a cache dir per test function does not work? 
diff --git a/tests/features/test_audio.py b/tests/features/test_audio.py index 43101fe762a..c2373e7007c 100644 --- a/tests/features/test_audio.py +++ b/tests/features/test_audio.py @@ -1,5 +1,7 @@ import os import tarfile +from contextlib import nullcontext +from unittest.mock import patch import pyarrow as pa import pytest @@ -7,7 +9,13 @@ from datasets import Dataset, concatenate_datasets, load_dataset from datasets.features import Audio, Features, Sequence, Value -from ..utils import require_libsndfile_with_opus, require_sndfile, require_sox, require_torchaudio +from ..utils import ( + require_libsndfile_with_opus, + require_sndfile, + require_sox, + require_torchaudio, + require_torchaudio_latest, +) @pytest.fixture() @@ -135,6 +143,26 @@ def test_audio_decode_example_mp3(shared_datadir): assert decoded_example["sampling_rate"] == 44100 +@pytest.mark.torchaudio_latest +@require_torchaudio_latest +@pytest.mark.parametrize("torchaudio_failed", [False, True]) +def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed): + audio_path = str(shared_datadir / "test_audio_44100.mp3") + audio = Audio() + + with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns( + UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?" 
+ ) if torchaudio_failed else nullcontext(): + + if torchaudio_failed: + load_mock.side_effect = RuntimeError() + + decoded_example = audio.decode_example(audio.encode_example(audio_path)) + assert decoded_example["path"] == audio_path + assert decoded_example["array"].shape == (110592,) + assert decoded_example["sampling_rate"] == 44100 + + @require_libsndfile_with_opus def test_audio_decode_example_opus(shared_datadir): audio_path = str(shared_datadir / "test_audio_48000.opus") @@ -178,6 +206,34 @@ def test_audio_resampling_mp3_different_sampling_rates(shared_datadir): assert decoded_example["sampling_rate"] == 48000 +@pytest.mark.torchaudio_latest +@require_torchaudio_latest +@pytest.mark.parametrize("torchaudio_failed", [False, True]) +def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed): + audio_path = str(shared_datadir / "test_audio_44100.mp3") + audio_path2 = str(shared_datadir / "test_audio_16000.mp3") + audio = Audio(sampling_rate=48000) + + # if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa) + with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns( + UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?" 
+ ) if torchaudio_failed else nullcontext(): + if torchaudio_failed: + load_mock.side_effect = RuntimeError() + + decoded_example = audio.decode_example(audio.encode_example(audio_path)) + assert decoded_example.keys() == {"path", "array", "sampling_rate"} + assert decoded_example["path"] == audio_path + assert decoded_example["array"].shape == (120373,) + assert decoded_example["sampling_rate"] == 48000 + + decoded_example = audio.decode_example(audio.encode_example(audio_path2)) + assert decoded_example.keys() == {"path", "array", "sampling_rate"} + assert decoded_example["path"] == audio_path2 + assert decoded_example["array"].shape == (122688,) + assert decoded_example["sampling_rate"] == 48000 + + @require_sndfile def test_dataset_with_audio_feature(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav") @@ -266,6 +322,38 @@ def test_dataset_with_audio_feature_tar_mp3(tar_mp3_path): assert column[0]["sampling_rate"] == 44100 +@pytest.mark.torchaudio_latest +@require_torchaudio_latest +def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path): + # no test for librosa here because it doesn't support file-like objects, only paths + audio_filename = "test_audio_44100.mp3" + data = {"audio": []} + for file_path, file_obj in iter_archive(tar_mp3_path): + data["audio"].append({"path": file_path, "bytes": file_obj.read()}) + break + features = Features({"audio": Audio()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"audio"} + assert item["audio"].keys() == {"path", "array", "sampling_rate"} + assert item["audio"]["path"] == audio_filename + assert item["audio"]["array"].shape == (110592,) + assert item["audio"]["sampling_rate"] == 44100 + batch = dset[:1] + assert batch.keys() == {"audio"} + assert len(batch["audio"]) == 1 + assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"} + assert batch["audio"][0]["path"] == audio_filename + assert 
batch["audio"][0]["array"].shape == (110592,) + assert batch["audio"][0]["sampling_rate"] == 44100 + column = dset["audio"] + assert len(column) == 1 + assert column[0].keys() == {"path", "array", "sampling_rate"} + assert column[0]["path"] == audio_filename + assert column[0]["array"].shape == (110592,) + assert column[0]["sampling_rate"] == 44100 + + @require_sndfile def test_dataset_with_audio_feature_with_none(): data = {"audio": [None]} @@ -328,7 +416,7 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir): @require_sox -@require_sndfile +@require_torchaudio def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.mp3") data = {"audio": [audio_path]} @@ -355,6 +443,43 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir): assert column[0]["sampling_rate"] == 16000 +@pytest.mark.torchaudio_latest +@require_torchaudio_latest +@pytest.mark.parametrize("torchaudio_failed", [False, True]) +def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed): + audio_path = str(shared_datadir / "test_audio_44100.mp3") + data = {"audio": [audio_path]} + features = Features({"audio": Audio(sampling_rate=16000)}) + dset = Dataset.from_dict(data, features=features) + + # if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa) + with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns( + UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?" 
+ ) if torchaudio_failed else nullcontext(): + if torchaudio_failed: + load_mock.side_effect = RuntimeError() + + item = dset[0] + assert item.keys() == {"audio"} + assert item["audio"].keys() == {"path", "array", "sampling_rate"} + assert item["audio"]["path"] == audio_path + assert item["audio"]["array"].shape == (40125,) + assert item["audio"]["sampling_rate"] == 16000 + batch = dset[:1] + assert batch.keys() == {"audio"} + assert len(batch["audio"]) == 1 + assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"} + assert batch["audio"][0]["path"] == audio_path + assert batch["audio"][0]["array"].shape == (40125,) + assert batch["audio"][0]["sampling_rate"] == 16000 + column = dset["audio"] + assert len(column) == 1 + assert column[0].keys() == {"path", "array", "sampling_rate"} + assert column[0]["path"] == audio_path + assert column[0]["array"].shape == (40125,) + assert column[0]["sampling_rate"] == 16000 + + @require_sndfile def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav") @@ -386,7 +511,7 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir): @require_sox -@require_sndfile +@require_torchaudio def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.mp3") data = {"audio": [audio_path]} @@ -416,6 +541,46 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir) assert column[0]["sampling_rate"] == 16000 +@pytest.mark.torchaudio_latest +@require_torchaudio_latest +@pytest.mark.parametrize("torchaudio_failed", [False, True]) +def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed): + audio_path = str(shared_datadir / "test_audio_44100.mp3") + data = {"audio": [audio_path]} + features = Features({"audio": Audio()}) + dset = Dataset.from_dict(data, features=features) + + # if 
torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa) + with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns( + UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?" + ) if torchaudio_failed else nullcontext(): + if torchaudio_failed: + load_mock.side_effect = RuntimeError() + + item = dset[0] + assert item["audio"]["sampling_rate"] == 44100 + dset = dset.cast_column("audio", Audio(sampling_rate=16000)) + item = dset[0] + assert item.keys() == {"audio"} + assert item["audio"].keys() == {"path", "array", "sampling_rate"} + assert item["audio"]["path"] == audio_path + assert item["audio"]["array"].shape == (40125,) + assert item["audio"]["sampling_rate"] == 16000 + batch = dset[:1] + assert batch.keys() == {"audio"} + assert len(batch["audio"]) == 1 + assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"} + assert batch["audio"][0]["path"] == audio_path + assert batch["audio"][0]["array"].shape == (40125,) + assert batch["audio"][0]["sampling_rate"] == 16000 + column = dset["audio"] + assert len(column) == 1 + assert column[0].keys() == {"path", "array", "sampling_rate"} + assert column[0]["path"] == audio_path + assert column[0]["array"].shape == (40125,) + assert column[0]["sampling_rate"] == 16000 + + @pytest.mark.parametrize( "build_data", [ diff --git a/tests/utils.py b/tests/utils.py index 70853c6cdcb..eecdef966d2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -64,7 +64,16 @@ def parse_flag_from_env(key, default=False): find_library("sox") is None, reason="test requires sox OS dependency; only available on non-Windows: 'sudo apt-get install sox'", ) -require_torchaudio = pytest.mark.skipif(find_spec("torchaudio") is None, reason="test requires torchaudio") +require_torchaudio = pytest.mark.skipif( + find_spec("torchaudio") is None + or version.parse(import_module("torchaudio").__version__) >= version.parse("0.12.0"), + reason="test requires torchaudio<0.12", +) 
+require_torchaudio_latest = pytest.mark.skipif( + find_spec("torchaudio") is None + or version.parse(import_module("torchaudio").__version__) < version.parse("0.12.0"), + reason="test requires torchaudio>=0.12", +) def require_beam(test_case):