File size: 1,749 Bytes
a00b67a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import json

from torch.utils.data import DataLoader
import soundfile as sf
import tqdm

from dataloader import DelimitValidDataset


def main(
    data_path="/path/to/musdb18hq",
    save_path="/path/to/musdb18hq_custom_limiter_fixed_attack",
    num_workers=1,
):
    """Render the limited validation set and its metadata to ``save_path``.

    Builds a ``DelimitValidDataset`` with the custom limiter enabled and the
    attack range pinned to ``[2.0, 2.0]``. Each track is then handled as follows:

    * its limited audio is written to ``<save_path>/valid/<name>.wav``;
    * its loudness is recorded in ``<save_path>/valid_loudness.json``;
    * its limiter attack/release are recorded in
      ``<save_path>/valid_limiter_params.json``.

    Args:
        data_path: Root directory passed to ``DelimitValidDataset``.
        save_path: Output directory; created (with ``valid/``) if missing.
        num_workers: Number of ``DataLoader`` worker processes.
    """
    batch_size = 1
    sr = 44100

    # Dataset
    dataset = DelimitValidDataset(
        root=data_path, use_custom_limiter=True, custom_limiter_attack_range=[2.0, 2.0]
    )
    data_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )

    # Create the output tree once, before the loop. This also guarantees that
    # save_path exists for the JSON files even if the loader yields no items.
    valid_dir = os.path.join(save_path, "valid")
    os.makedirs(valid_dir, exist_ok=True)

    dict_valid_loudness = {}
    dict_limiter_params = {}
    for (
        limited_audio,
        _orig_audio,  # unused: only the limited version is exported
        audio_names,
        loudness,
        custom_attack,
        custom_release,
    ) in tqdm.tqdm(data_loader):
        # Walk every item in the batch instead of only index 0, so changing
        # batch_size cannot silently drop tracks.
        for i, audio_name in enumerate(audio_names):
            dict_valid_loudness[audio_name] = float(loudness[i])
            dict_limiter_params[audio_name] = {
                "attack_ms": float(custom_attack[i]),
                "release_ms": float(custom_release[i]),
            }
            # NOTE(review): .T presumably converts (channels, samples) into the
            # (frames, channels) layout soundfile expects — TODO confirm.
            sf.write(
                os.path.join(valid_dir, f"{audio_name}.wav"),
                limited_audio[i].numpy().T,
                sr,
            )

    # Persist the per-track metadata collected above.
    for file_name, payload in (
        ("valid_loudness.json", dict_valid_loudness),
        ("valid_limiter_params.json", dict_limiter_params),
    ):
        with open(os.path.join(save_path, file_name), "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=4)


if __name__ == "__main__":
    # Script entry point: run the preprocessing only when executed directly,
    # never on import. main() returns None, so the exit status is 0.
    raise SystemExit(main())