File size: 6,769 Bytes
3978e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import shutil
from test import test_settings
from scripts.redact_config import redact_config
from utils import load_config
from pathlib import Path
import os
import numpy as np
import soundfile as sf
from typing import List, Dict

# Config files (looked up in CONFIGS_DIR) exercised by run_tests(), each mapped to
# the per-config test arguments. 'model_type' is passed to redact_config() and
# load_config(), and the whole dict is merged into the test_settings() arguments.
# Commented-out entries are currently excluded from the test run.
MODEL_CONFIGS = {
    'config_apollo.yaml': {'model_type': 'apollo'},
    'config_dnr_bandit_bsrnn_multi_mus64.yaml': {'model_type': 'bandit'},
    'config_dnr_bandit_v2_mus64.yaml': {'model_type': 'bandit_v2'},
    'config_drumsep.yaml': {'model_type': 'htdemucs'},
    'config_htdemucs_6stems.yaml': {'model_type': 'htdemucs'},
    'config_musdb18_bs_roformer.yaml': {'model_type': 'bs_roformer'},
    'config_musdb18_demucs3_mmi.yaml': {'model_type': 'htdemucs'},
    'config_musdb18_htdemucs.yaml': {'model_type': 'htdemucs'},
    'config_musdb18_mdx23c.yaml': {'model_type': 'mdx23c'},
    'config_musdb18_mel_band_roformer.yaml': {'model_type': 'mel_band_roformer'},
    'config_musdb18_mel_band_roformer_all_stems.yaml': {'model_type': 'mel_band_roformer'},
    'config_musdb18_scnet.yaml': {'model_type': 'scnet'},
    'config_musdb18_scnet_large.yaml': {'model_type': 'scnet'},
    # 'config_musdb18_scnet_large_starrytong.yaml': {'model_type': 'scnet'},
    'config_vocals_bandit_bsrnn_multi_mus64.yaml': {'model_type': 'bandit'},
    'config_vocals_bs_roformer.yaml': {'model_type': 'bs_roformer'},
    'config_vocals_htdemucs.yaml': {'model_type': 'htdemucs'},
    'config_vocals_mdx23c.yaml': {'model_type': 'mdx23c'},
    'config_vocals_mel_band_roformer.yaml': {'model_type': 'mel_band_roformer'},
    'config_vocals_scnet.yaml': {'model_type': 'scnet'},
    'config_vocals_scnet_large.yaml': {'model_type': 'scnet'},
    'config_vocals_scnet_unofficial.yaml': {'model_type': 'scnet_unofficial'},
    'config_vocals_segm_models.yaml': {'model_type': 'segm_models'},


    # Disabled configs (not part of the current test run):
    # 'config_vocals_swin_upernet.yaml': {'model_type': 'swin_upernet'},
    # 'config_musdb18_torchseg.yaml': {'model_type': 'torchseg'},
    # 'config_musdb18_segm_models.yaml': {'model_type': 'segm_models'},
    # 'config_musdb18_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
    # 'config_vocals_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
    # 'config_vocals_torchseg.yaml': {'model_type': 'torchseg'}
}


# Folders for tests.
# ROOT_DIR is the parent of the directory that contains this file.
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Original (unmodified) model configs, read by modify_configs().
CONFIGS_DIR = ROOT_DIR / 'configs/'
# Scratch area for generated test data: redacted configs, dummy tracks and results.
TEST_DIR = ROOT_DIR / "tests_cache/"
# Dummy training / validation tracks produced by create_dummy_tracks().
TRAIN_DIR = TEST_DIR / "train_tracks/"
VALID_DIR = TEST_DIR / "valid_tracks/"


def create_dummy_tracks(directory: Path, num_tracks: int, instruments: List[str],
                        duration: float = 5.0, sample_rate: int = 44100) -> None:
    """
    Generate random stereo noise stems for a set of numbered track folders.

    Creates ``num_tracks`` subfolders named ``1`` .. ``num_tracks`` inside
    ``directory`` and writes one ``<instrument>.wav`` file per entry of
    ``instruments`` into each of them. Existing folders are reused and files
    with the same name are overwritten.

    Parameters:
    ----------
    directory : Path or str
        Directory in which the numbered track folders are created (created
        itself, including parents, if missing).
    num_tracks : int
        Number of track folders to generate.
    instruments : List[str]
        Names of the stems to write into every track folder.
    duration : float, optional
        Duration of each stem in seconds. Default is 5.0.
    sample_rate : int, optional
        Sampling rate of the generated audio. Default is 44100 Hz.

    Returns:
    -------
    None
    """
    # Accept plain string paths as well; the `/` operator below needs a Path.
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)

    # Every stem has the same length, so compute the frame count once.
    num_samples = int(duration * sample_rate)

    for track_idx in range(1, num_tracks + 1):
        folder_path = directory / str(track_idx)
        folder_path.mkdir(parents=True, exist_ok=True)
        for instrument in instruments:
            # Uniform white noise in [-1, 1), laid out as (channels=2, samples).
            track = np.random.uniform(-1.0, 1.0, (2, num_samples)).astype(np.float32)
            # sf.write is given (samples, channels), hence the transpose.
            sf.write(folder_path / f"{instrument}.wav", track.T, sample_rate)


def cleanup_test_tracks(directory: "Path | None" = None) -> None:
    """
    Remove all cached test data.

    Deletes the whole directory tree rooted at ``directory`` (the global
    ``TEST_DIR`` by default) if it exists; does nothing when it is missing.

    Parameters
    ----------
    directory : Path or str, optional
        Directory to delete. Defaults to the global ``TEST_DIR``.

    Returns
    -------
    None
        This function does not return a value. It performs cleanup of test data.
    """
    # TEST_DIR is only looked up when no explicit directory is given.
    target = TEST_DIR if directory is None else Path(directory)
    if target.exists():
        shutil.rmtree(target)


def modify_configs() -> Dict[str, Path]:
    """
    Produce test copies of every config listed in ``MODEL_CONFIGS``.

    Each config is read from ``CONFIGS_DIR`` and passed to ``redact_config``
    together with its model type. The redacted copy is targeted at
    ``TEST_DIR / 'configs' / <config name>``.

    Returns:
    -------
    Dict[str, Path]
        Maps each original config file name to the path returned by
        ``redact_config`` for its redacted copy.
    """
    redacted_dir = TEST_DIR / 'configs'
    return {
        config_name: redact_config({
            'orig_config': str(CONFIGS_DIR / config_name),
            'model_type': spec['model_type'],
            'new_config': str(redacted_dir / config_name),
        })
        for config_name, spec in MODEL_CONFIGS.items()
    }


def run_tests() -> None:
    """
    Run validation and inference checks for every entry in ``MODEL_CONFIGS``.

    For each config this function performs three steps:

    * loads the test copy produced by ``modify_configs``;
    * generates random dummy train/valid tracks for the configured instruments
      plus a ``mixture`` stem;
    * calls ``test_settings`` with validation and inference checks enabled
      and training disabled.

    ``cleanup_test_tracks`` is called once every config has been processed.

    Returns:
    -------
    None
    """

    updated_configs = modify_configs()

    # For every config
    for config, args in MODEL_CONFIGS.items():
        model_type = args['model_type']
        # Load the exact file returned by redact_config, i.e. the same path that
        # is handed to test_settings below, instead of re-deriving it.
        config_path = updated_configs[config]
        cfg = load_config(model_type=model_type, config_path=config_path)

        # Random tracks: one stem per configured instrument plus the mixture.
        stems = cfg.training.instruments + ['mixture']
        create_dummy_tracks(TRAIN_DIR, instruments=stems, num_tracks=2)
        create_dummy_tracks(VALID_DIR, instruments=stems, num_tracks=2)

        print(f"\nRunning tests for model: {model_type} (config: {config})")

        test_args = {
            'check_train': False,
            'check_valid': True,
            'check_inference': True,
            'config_path': config_path,
            'data_path': str(TRAIN_DIR),
            'valid_path': str(VALID_DIR),
            'results_path': str(TEST_DIR / "results" / model_type),
            'store_dir': str(TEST_DIR / "inference_results" / model_type),
            'metrics': ['sdr', 'si_sdr', 'l1_freq']
        }

        # Per-config entries (e.g. 'model_type') are merged in last.
        test_args.update(args)

        test_settings(test_args, 'admin')
        print(f"Tests for model {model_type} completed successfully.")

    # Remove test_cache
    cleanup_test_tracks()


# Entry point: run the whole test suite when executed as a script.
if __name__ == "__main__":
    run_tests()