Skip to content

Commit

Permalink
Merge pull request #368 from claritychallenge/jpb-fix-to-cec3-recipe
Browse files Browse the repository at this point in the history
Fix to the CEC3 baseline recipes
  • Loading branch information
jonbarker68 committed Apr 4, 2024
2 parents 66ed680 + 98e615a commit 707c5a3
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"pytorch-lightning",
"resampy",
"scikit-learn>=1.0.2",
"scipy>=1.7.3",
"scipy>=1.7.3, <1.13.0",
"SoundFile>=0.10.3.post1",
"soxr",
"torch>=2",
Expand Down
19 changes: 14 additions & 5 deletions recipes/cec3/README.md
Expand Up @@ -65,22 +65,31 @@ These can be set in the `config.yaml` file or provided on the command line. In t

The baseline enhancement simply takes the 6-channel hearing aid inputs and reduces this to a stereo hearing aid output by passing through the 'front' microphone signal of the left and right ear.

Alternatively, you can provide the root variable on the command line, e.g.,
The stereo pair is then passed through a provided hearing aid amplification stage using a NAL-R [[1](#references)] fitting amplification and a simple automatic gain compressor. The amplification is determined by the audiograms defined by the scene-listener pairs in `clarity_data/metadata/scenes_listeners.dev.json` for the development set. After amplification, the evaluate function calculates the better-ear HASPI [[2](#references)].

To run the baseline enhancement system, first set the `task`, `path.root` and `path.exp` variables in the `config.yaml` file and then run,

```bash
python enhance.py
```

Alternatively, you can provide the task and paths on the command line, e.g.,

```bash
python enhance.py task=task1 path.root=/Users/jon/clarity_CEC3_data path.exp=/Users/jon/exp
```

Where '/Users/jon' is replaced with the path to the root of the clarity data and the experiment folder.
Where `/Users/jon` is replaced with the path to the root of the clarity data and the experiment folder.

The folder `enhanced_signals` will appear in the `exp` folder. Note, the experiment folder will be created if it does not already exist.
The folders `enhanced_signals` and `amplified_signals` will appear in the `exp` folder. Note, the experiment folder will be created if it does not already exist.

### Evaluation

The `evaluate.py` will first pass signals through a provided hearing aid amplification stage using a NAL-R [[1](#references)] fitting amplification and a simple automatic gain compressor. The amplification is determined by the audiograms defined by the scene-listener pairs in `clarity_data/metadata/scenes_listeners.dev.json` for the development set. After amplification, the evaluate function calculates the better-ear HASPI [[2](#references)].
The evaluate script computes the HASPI scores for the signals stored in the `amplified_signals` folder. The script will read the scene-listener pairs from the development set and calculate the HASPI score for each pair. The final score is the mean HASPI score across all pairs. It can be run as,

```bash
python evaluate.py
python evaluate.py task=task1 path.root=/Users/jon/clarity_CEC3_data path.exp=/Users/jon/exp
```

The full evaluation set is 7500 scene-listener pairs and will take a long time to run, i.e., around 8 hours on a MacBook Pro. A standard small set which uses 1/15 of the data has been defined. This takes around 30 minutes to evaluate and can be run with,
Expand Down
39 changes: 33 additions & 6 deletions recipes/cec3/baseline/enhance.py
Expand Up @@ -10,12 +10,22 @@
from scipy.io import wavfile
from tqdm import tqdm

from clarity.utils.audiogram import Listener
from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.utils.audiogram import Audiogram, Listener
from recipes.icassp_2023.baseline.evaluate import make_scene_listener_list

logger = logging.getLogger(__name__)


def amplify_signal(signal, audiogram: Audiogram, enhancer, compressor):
    """Apply NAL-R amplification and compression for one audiogram.

    Builds the NAL-R FIR filter matching the listener's audiogram, filters
    the signal with it, and then passes the result through the compressor.

    Args:
        signal: Single-channel signal to amplify.
        audiogram: Listener audiogram used to design the NAL-R filter.
        enhancer: Object providing ``build()`` and ``apply()`` (e.g. NALR).
        compressor: Object providing ``process()``.

    Returns:
        The amplified (filtered and compressed) signal.
    """
    # Design the FIR filter for this audiogram; the second return value
    # from build() is not needed here.
    fir_coefficients, _ = enhancer.build(audiogram)
    filtered = enhancer.apply(fir_coefficients, signal)
    # process() returns a tuple; only the processed signal is kept.
    compressed, _, _ = compressor.process(filtered)
    return compressed


@hydra.main(config_path=".", config_name="config")
def enhance(cfg: DictConfig) -> None:
"""Run the dummy enhancement."""
Expand All @@ -27,6 +37,10 @@ def enhance(cfg: DictConfig) -> None:
scenes_listeners = json.load(fp)

listener_dict = Listener.load_listener_dict(cfg.path.listeners_file)
enhancer = NALR(**cfg.nalr)
compressor = Compressor(**cfg.compressor)
amplified_folder = pathlib.Path(cfg.path.exp) / "amplified_signals"
amplified_folder.mkdir(parents=True, exist_ok=True)

# Make list of all scene listener pairs that will be run
scene_listener_pairs = make_scene_listener_list(
Expand Down Expand Up @@ -56,14 +70,27 @@ def enhance(cfg: DictConfig) -> None:
# pylint: disable=unused-variable
listener = listener_dict[listener_id] # noqa: F841

# Note: The audiograms are stored in the listener object,
# but they are not needed for the baseline
wavfile.write(
enhanced_folder / f"{scene}_{listener_id}_enhanced.wav", sample_rate, signal
)

# Apply the baseline NALR amplification

        # Amplify the left and right front-microphone channels
        # independently, using the listener's per-ear audiograms
out_l = amplify_signal(
signal[:, 0], listener.audiogram_left, enhancer, compressor
)
out_r = amplify_signal(
signal[:, 1], listener.audiogram_right, enhancer, compressor
)
amplified = np.stack([out_l, out_r], axis=1)

if cfg.soft_clip:
amplified = np.tanh(amplified)

wavfile.write(
enhanced_folder / f"{scene}_{listener_id}_enhanced.wav", sample_rate, signal
amplified_folder / f"{scene}_{listener_id}_HA-output.wav",
sample_rate,
amplified.astype(np.float32),
)


Expand Down
42 changes: 4 additions & 38 deletions recipes/cec3/baseline/evaluate.py
Expand Up @@ -13,22 +13,12 @@
from scipy.io import wavfile
from tqdm import tqdm

from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.evaluator.haspi import haspi_v2_be
from clarity.utils.audiogram import Audiogram, Listener
from clarity.utils.audiogram import Listener

logger = logging.getLogger(__name__)


def amplify_signal(signal, audiogram: Audiogram, enhancer, compressor):
    """Amplify signal for a given audiogram.

    Builds a NAL-R FIR filter from the audiogram, filters the signal with
    it, and passes the result through the compressor.

    Args:
        signal: Single-channel signal to amplify.
        audiogram: Listener audiogram used to design the NAL-R filter.
        enhancer: Object providing ``build()`` and ``apply()`` (e.g. NALR).
        compressor: Object providing ``process()``.

    Returns:
        The amplified (filtered and compressed) signal.
    """
    # Design the FIR filter for this audiogram; second return value unused.
    nalr_fir, _ = enhancer.build(audiogram)
    out = enhancer.apply(nalr_fir, signal)
    # process() returns a tuple; only the processed signal is kept.
    out, _, _ = compressor.process(out)
    return out


def set_scene_seed(scene):
"""Set a seed that is unique for the given scene"""
scene_encoded = hashlib.md5(scene.encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -100,16 +90,12 @@ def run_calculate_si(cfg: DictConfig) -> None:
scenes_listeners = json.load(fp)

listeners_dict = Listener.load_listener_dict(cfg.path.listeners_file)
enhancer = NALR(**cfg.nalr)
compressor = Compressor(**cfg.compressor)

enhanced_folder = pathlib.Path(cfg.path.exp) / "enhanced_signals"
amplified_folder = pathlib.Path(cfg.path.exp) / "amplified_signals"
scenes_folder = pathlib.Path(cfg.path.scenes_folder)
amplified_folder.mkdir(parents=True, exist_ok=True)

# Make list of all scene listener pairs that will be run

scene_listener_pairs = make_scene_listener_list(
scenes_listeners, cfg.evaluate.small_test
)
Expand All @@ -135,9 +121,8 @@ def run_calculate_si(cfg: DictConfig) -> None:
set_scene_seed(scene)

# Read signals

sr_signal, signal = wavfile.read(
enhanced_folder / f"{scene}_{listener_id}_enhanced.wav"
amplified_folder / f"{scene}_{listener_id}_HA-output.wav",
)
_, reference = wavfile.read(scenes_folder / f"{scene}_reference.wav")

Expand All @@ -147,30 +132,11 @@ def run_calculate_si(cfg: DictConfig) -> None:

reference = reference / 32768.0

# amplify left and right ear signals
# Evaluate the HA-output signals
listener = listeners_dict[listener_id]

out_l = amplify_signal(
signal[:, 0], listener.audiogram_left, enhancer, compressor
)
out_r = amplify_signal(
signal[:, 1], listener.audiogram_right, enhancer, compressor
)
amplified = np.stack([out_l, out_r], axis=1)

if cfg.soft_clip:
amplified = np.tanh(amplified)

wavfile.write(
amplified_folder / f"{scene}_{listener_id}_HA-output.wav",
sr_signal,
amplified.astype(np.float32),
)

# Evaluate the amplified signal

haspi_score = compute_metric(
haspi_v2_be, amplified, reference, listener, sr_signal
haspi_v2_be, signal, reference, listener, sr_signal
)

results_file.add_result(scene, listener_id, haspi_score)
Expand Down

0 comments on commit 707c5a3

Please sign in to comment.