|
import shutil |
|
from test import test_settings |
|
from scripts.redact_config import redact_config |
|
from utils import load_config |
|
from pathlib import Path |
|
import os |
|
import numpy as np |
|
import soundfile as sf |
|
from typing import List, Dict |
|
|
|
# Config files exercised by the admin test sweep, each mapped to the extra
# arguments (currently just ``model_type``) used to load it and forwarded to
# ``test_settings``. Insertion order is the order in which configs are tested.
MODEL_CONFIGS = {
    config_name: {'model_type': model_type}
    for config_name, model_type in (
        # MUSDB18 / DnR multi-stem configs
        ('config_apollo.yaml', 'apollo'),
        ('config_dnr_bandit_bsrnn_multi_mus64.yaml', 'bandit'),
        ('config_dnr_bandit_v2_mus64.yaml', 'bandit_v2'),
        ('config_drumsep.yaml', 'htdemucs'),
        ('config_htdemucs_6stems.yaml', 'htdemucs'),
        ('config_musdb18_bs_roformer.yaml', 'bs_roformer'),
        ('config_musdb18_demucs3_mmi.yaml', 'htdemucs'),
        ('config_musdb18_htdemucs.yaml', 'htdemucs'),
        ('config_musdb18_mdx23c.yaml', 'mdx23c'),
        ('config_musdb18_mel_band_roformer.yaml', 'mel_band_roformer'),
        ('config_musdb18_mel_band_roformer_all_stems.yaml', 'mel_band_roformer'),
        ('config_musdb18_scnet.yaml', 'scnet'),
        ('config_musdb18_scnet_large.yaml', 'scnet'),
        # Vocals-oriented configs
        ('config_vocals_bandit_bsrnn_multi_mus64.yaml', 'bandit'),
        ('config_vocals_bs_roformer.yaml', 'bs_roformer'),
        ('config_vocals_htdemucs.yaml', 'htdemucs'),
        ('config_vocals_mdx23c.yaml', 'mdx23c'),
        ('config_vocals_mel_band_roformer.yaml', 'mel_band_roformer'),
        ('config_vocals_scnet.yaml', 'scnet'),
        ('config_vocals_scnet_large.yaml', 'scnet'),
        ('config_vocals_scnet_unofficial.yaml', 'scnet_unofficial'),
        ('config_vocals_segm_models.yaml', 'segm_models'),
    )
}
|
|
|
|
|
|
|
# Repository root: the directory two levels above this file.
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent

# Source configs that get redacted into test copies.
CONFIGS_DIR = ROOT_DIR / 'configs/'

# Scratch area for everything the test sweep generates.
TEST_DIR = ROOT_DIR / "tests_cache/"
TRAIN_DIR = TEST_DIR / "train_tracks/"
VALID_DIR = TEST_DIR / "valid_tracks/"
|
|
|
|
|
def create_dummy_tracks(directory: Path, num_tracks: int, instruments: List[str],
                        duration: float = 5.0, sample_rate: int = 44100) -> None:
    """
    Generate random stereo audio stems for a set of numbered track folders.

    Creates ``num_tracks`` subdirectories named ``1`` .. ``num_tracks`` inside
    ``directory`` and writes one ``<instrument>.wav`` file into each of them per
    entry in ``instruments``. Every file holds two channels of uniform noise in
    [-1.0, 1.0). Files with the same names are overwritten; any other files
    already present in the folders are left untouched.

    Parameters:
    ----------
    directory : Path
        Directory in which the numbered track folders are created (created,
        with parents, if missing). A plain string path is also accepted.
    num_tracks : int
        Number of numbered track folders to generate.
    instruments : List[str]
        Stem names; each becomes ``<name>.wav`` inside every track folder.
    duration : float, optional
        Duration of each stem in seconds. Default is 5.0.
    sample_rate : int, optional
        Sampling rate of the generated audio. Default is 44100 Hz.

    Returns:
    -------
    None
    """
    # Accept str paths too: the `/` operator below requires a Path.
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)

    # Every stem has the same length, so compute the frame count once.
    num_samples = int(duration * sample_rate)

    for track_idx in range(1, num_tracks + 1):
        folder_path = directory / str(track_idx)
        folder_path.mkdir(parents=True, exist_ok=True)

        for instrument in instruments:
            # Generated as (channels, samples); transposed on write because
            # soundfile expects (frames, channels).
            track = np.random.uniform(-1.0, 1.0, (2, num_samples)).astype(np.float32)
            sf.write(str(folder_path / f"{instrument}.wav"), track.T, sample_rate)
|
|
|
|
|
def cleanup_test_tracks() -> None:
    """
    Remove all cached test data.

    Deletes the entire directory tree at the global ``TEST_DIR`` if it exists
    as a directory; otherwise does nothing, so calling it repeatedly is safe.

    Returns:
    -------
    None
        This function does not return a value. It performs cleanup of test data.
    """
    if TEST_DIR.is_dir():
        shutil.rmtree(TEST_DIR)
|
|
|
|
|
def modify_configs() -> Dict[str, Path]:
    """
    Write test-ready copies of every config listed in ``MODEL_CONFIGS``.

    For each entry, the source file ``CONFIGS_DIR / <name>`` is passed to
    ``redact_config`` along with the entry's ``model_type``, and the redacted
    copy is requested at ``TEST_DIR / 'configs' / <name>``.

    Returns:
    -------
    Dict[str, Path]
        Maps each original config file name to the value ``redact_config``
        returned for it (presumably the path of the written copy — confirm
        against ``scripts.redact_config``).
    """
    output_dir = TEST_DIR / 'configs'
    redacted: Dict[str, Path] = {}

    for name, spec in MODEL_CONFIGS.items():
        request = {
            'orig_config': str(CONFIGS_DIR / name),
            'model_type': spec['model_type'],
            'new_config': str(output_dir / name),
        }
        redacted[name] = redact_config(request)

    return redacted
|
|
|
|
|
def run_tests() -> None:
    """
    Execute validation and inference tests for every config in ``MODEL_CONFIGS``.

    Steps:
      1. ``modify_configs()`` writes test-ready copies of all configs under
         ``TEST_DIR / 'configs'``.
      2. For each config, the train/valid dummy folders are wiped and refilled
         with random tracks for the config's training instruments plus a
         ``mixture`` stem. Then ``test_settings`` is run with validation and
         inference checks enabled and training disabled.
      3. Once every config has run, ``cleanup_test_tracks()`` removes the cache.
         If a test raises, the cache is left in place for inspection.

    Results and inference outputs are namespaced by model type *and* config
    name. Several configs share a model type (e.g. five ``htdemucs`` configs),
    and with model type alone they would overwrite each other's outputs.

    Returns:
    -------
    None
    """
    updated_configs = modify_configs()

    for config, args in MODEL_CONFIGS.items():
        model_type = args['model_type']
        config_stem = Path(config).stem
        cfg = load_config(model_type=model_type, config_path=TEST_DIR / 'configs' / config)

        stems = cfg.training.instruments + ['mixture']
        # Start from empty folders: a previous config may have used a different
        # stem set, and its extra .wav files would otherwise linger here.
        for data_dir in (TRAIN_DIR, VALID_DIR):
            shutil.rmtree(data_dir, ignore_errors=True)
            create_dummy_tracks(data_dir, instruments=stems, num_tracks=2)

        print(f"\nRunning tests for model: {model_type} (config: {config})")

        test_args = {
            'check_train': False,
            'check_valid': True,
            'check_inference': True,
            'config_path': updated_configs[config],
            'data_path': str(TRAIN_DIR),
            'valid_path': str(VALID_DIR),
            'results_path': str(TEST_DIR / "results" / model_type / config_stem),
            'store_dir': str(TEST_DIR / "inference_results" / model_type / config_stem),
            'metrics': ['sdr', 'si_sdr', 'l1_freq'],
        }
        # Merge the per-config extras from MODEL_CONFIGS (e.g. 'model_type').
        test_args.update(args)

        test_settings(test_args, 'admin')
        print(f"Tests for model {model_type} completed successfully.")

    # Only after the loop: the redacted configs under TEST_DIR are still
    # needed by later iterations.
    cleanup_test_tracks()
|
|
|
|
|
# Script entry point: run the full admin test sweep over MODEL_CONFIGS.
if __name__ == "__main__":
    run_tests()
|
|