import os

import soundfile as sf
import torch
import librosa
import matplotlib.pyplot as plt

from dataloader import SingleTrackSet
from utils import db2linear


def conv_tasnet_separate(
    args, our_model, device, track_audio, track_name, meter=None, augmented_gain=None
):
    """Separate a track with the Conv-TasNet model and write the estimate to disk.

    Returns the estimated source as a (channels, samples) numpy array.
    """
    if args.use_singletrackset:
        # Chunked inference: split the track into fixed-size windows to bound memory use.
        db = SingleTrackSet(
            track_audio.squeeze(dim=0),
            hop_length=args.data_params.nhop,
            num_frame=128,
            target_name=args.target,
        )
        separated = []

        for item in db:
            item = item.unsqueeze(0).to(device)
            estimates, *estimates_vars = our_model(item)
            if args.task_params.dataset == "delimit":
                estimates = estimates_vars[0]

            # Trim the padded edges of each window before stitching chunks together.
            estimates = estimates.cpu().detach()
            separated.append(
                estimates[..., db.trim_length : -db.trim_length].clone()
            )

        estimates = torch.cat(separated, dim=-1)
        estimates = estimates[0, :, : track_audio.shape[-1]].numpy()
    else:
        estimates, *estimates_vars = our_model(track_audio)
        if args.save_histogram and args.task_params.dataset == "delimit":
            plt.figure(figsize=(10, 10))
            plt.hist(estimates.cpu().detach().numpy().flatten(), bins=100)
            os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
            plt.savefig(
                f"{args.test_output_dir}/{track_name}/{args.target}_histogram.png"
            )
            plt.close()  # release the figure so repeated calls do not leak memory
        if args.task_params.dataset == "delimit":
            estimates = estimates_vars[0]

        estimates = estimates.cpu().detach().numpy()
        estimates = estimates[0, :, : track_audio.shape[-1]]

    if args.save_name_as_target:
        os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)

    if args.save_output_loudnorm:
        print("Saving loudness-normalized output")
        # Scale the estimate so its integrated loudness matches the target LUFS.
        loudness = meter.integrated_loudness(estimates.T)
        estimates = estimates * db2linear(args.save_output_loudnorm - loudness, eps=0.0)
    elif augmented_gain is not None and args.save_output_loudnorm is None:
        # Undo the gain augmentation that was applied before inference.
        estimates = estimates * db2linear(-augmented_gain, eps=0.0)

    sf.write(
        f"{args.test_output_dir}/{track_name}/{args.target}.wav"
        if args.save_name_as_target
        else f"{args.test_output_dir}/{track_name}.wav",
        estimates.T,
        samplerate=args.data_params.sample_rate,
    )

    if args.save_16k_mono:
        # Also export a 16 kHz mono copy (e.g. for downstream evaluation tools).
        estimates_16k_mono = librosa.to_mono(estimates)
        estimates_16k_mono = librosa.resample(
            estimates_16k_mono,
            orig_sr=args.data_params.sample_rate,
            target_sr=16000,
        )
        os.makedirs(f"{args.test_output_dir}_16k_mono/{track_name}", exist_ok=True)
        sf.write(
            f"{args.test_output_dir}_16k_mono/{track_name}/{args.target}.wav"
            if args.save_name_as_target
            else f"{args.test_output_dir}_16k_mono/{track_name}.wav",
            estimates_16k_mono,
            samplerate=16000,
        )

    return estimates