File size: 7,394 Bytes
fe781a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import argparse
import os
from pathlib import Path
import sys
import torchaudio
import numpy as np
from time import time
import torch
import typing as tp
from omegaconf import OmegaConf
from vocos import VocosDecoder
from models.soundstream_hubert_new import SoundStream
from tqdm import tqdm

def build_soundstream_model(config):
    """Instantiate the codec generator described by ``config.generator``.

    The class is looked up by evaluating ``config.generator.name`` in this
    module's namespace and constructed with ``config.generator.config``
    unpacked as keyword arguments.
    """
    # NOTE(review): eval() on a config-supplied string executes arbitrary code
    # if the config file is untrusted — an explicit name->class map would be safer.
    generator_cls = eval(config.generator.name)
    return generator_cls(**config.generator.config)

def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
    """Build the two Vocos decoders (vocal and instrumental) from one config.

    Args:
        config_path: Path to the shared Vocos hparams/config file.
        vocal_decoder_path: Checkpoint (state dict) for the vocal decoder.
        inst_decoder_path: Checkpoint (state dict) for the instrumental decoder.

    Returns:
        Tuple of (vocal_decoder, inst_decoder).
    """
    # map_location='cpu' so GPU-saved checkpoints also load on CPU-only hosts;
    # the decoders are moved to the CUDA device later, in process_audio.
    # NOTE(review): consider weights_only=True if the installed torch supports it.
    vocal_decoder = VocosDecoder.from_hparams(config_path=config_path)
    vocal_decoder.load_state_dict(torch.load(vocal_decoder_path, map_location='cpu'))
    inst_decoder = VocosDecoder.from_hparams(config_path=config_path)
    inst_decoder.load_state_dict(torch.load(inst_decoder_path, map_location='cpu'))
    return vocal_decoder, inst_decoder

def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
    """Write a waveform to disk as an MP3 file.

    Args:
        wav: Audio tensor; assumes (channels, samples) layout as required by
            torchaudio.save — TODO confirm against callers.
        path: Destination path; its suffix is forced to ``.mp3``.
        sample_rate: Sample rate in Hz passed through to torchaudio.
        rescale: If True, scale the signal so its peak sits at ``limit``
            (only ever scales down); otherwise hard-clip to [-limit, limit].
    """
    limit = 0.99
    mx = wav.abs().max()
    # Guard mx > 0: an all-zero (silent) signal would otherwise divide by
    # zero and turn the entire tensor into inf/NaN.
    if rescale and mx > 0:
        wav = wav * min(limit / mx, 1)
    else:
        wav = wav.clamp(-limit, limit)

    path = str(Path(path).with_suffix('.mp3'))
    torchaudio.save(path, wav, sample_rate=sample_rate)

def process_audio(input_file, output_file, rescale, args, decoder, soundstream):
    """Decode one .npy codec-token file to a waveform and save it as MP3.

    Args:
        input_file: Path to a .npy file of integer quantizer codes.
        output_file: Destination audio path (suffix forced to .mp3 by save_audio).
        rescale: Forwarded to save_audio (peak-rescale vs. hard clip).
        args: Parsed CLI namespace; ``args.cuda_idx`` (if present) selects the
            GPU, and ``args.bw`` is overwritten here.
        decoder: Vocos decoder that synthesises the waveform from embeddings.
        soundstream: Codec model providing ``get_embed`` for the token codes.

    Returns:
        The decoded waveform as a detached CPU tensor.
    """
    # Fall back to GPU 0 when the CLI did not define --cuda_idx (the original
    # argparser never registered that flag, which raised AttributeError here).
    device = f"cuda:{getattr(args, 'cuda_idx', 0)}"

    compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
    print(f"Processing {input_file}")
    print(f"Compressed shape: {compressed.shape}")

    # NOTE(review): bw is forced to 4 but never read again in this function —
    # presumably consumed inside the soundstream model; confirm.
    args.bw = float(4)
    compressed = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1)
    compressed = soundstream.get_embed(compressed.to(device))
    # as_tensor avoids the "copy construct" UserWarning that torch.tensor()
    # emits when get_embed already returns a tensor, and skips the extra copy.
    compressed = torch.as_tensor(compressed).to(device)

    start_time = time()
    with torch.no_grad():
        decoder.eval()
        decoder = decoder.to(device)
        out = decoder(compressed)
        out = out.detach().cpu()
    duration = time() - start_time
    # Real-time factor: seconds of 44.1 kHz audio produced per second of wall time.
    rtf = (out.shape[1] / 44100.0) / duration
    print(f"Decoded in {duration:.2f}s ({rtf:.2f}x RTF)")

    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_audio(out, output_file, 44100, rescale=rescale)
    print(f"Saved: {output_file}")
    return out

def find_matching_pairs(input_folder):
    """Pair up instrumental/vocal .npy files that share a base name.

    ``input_folder`` is either a directory scanned for ``*.npy`` files, or a
    ``.lst`` text file listing one path per line.

    Returns:
        List of ``(instrumental_path, vocal_path, base_name)`` triples, where
        ``base_name`` is the lower-cased stem with the stem's 'instrumental' /
        'vocal' marker removed.
    """
    if str(input_folder).endswith('.lst'):  # Convert to string
        with open(input_folder, 'r') as listing:
            files = [entry.strip() for entry in listing if entry.strip()]
    else:
        files = list(Path(input_folder).glob('*.npy'))
    print(f"found {len(files)} npy.")

    instrumental_files = {}
    vocal_files = {}
    for entry in files:
        path = entry if isinstance(entry, Path) else Path(entry)
        lowered = path.stem.lower()
        if 'instrumental' in lowered:
            # Key = stem with every 'instrumental' occurrence removed.
            instrumental_files[lowered.replace('instrumental', '')] = path
        elif 'vocal' in lowered:
            # Key = stem with only the LAST 'vocal' occurrence removed.
            cut = lowered.rfind('vocal')
            vocal_files[lowered[:cut] + lowered[cut + len('vocal'):]] = path

    # Keep only base names present on both sides, in instrumental insertion order.
    return [
        (instrumental_files[key], vocal_files[key], key)
        for key in instrumental_files
        if key in vocal_files
    ]

def main():
    """CLI entry point: decode paired instrumental/vocal token files into
    per-song stems and a summed mix, skipping songs whose mix already exists."""
    parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
    parser.add_argument('--input_folder', type=Path, required=True, help='Input folder containing NPY files.')
    parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
    parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to model checkpoint.')
    parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
    parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
    parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
    # Fix: process_audio reads args.cuda_idx, which was never registered before.
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index used for decoding.')
    parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
    args = parser.parse_args()

    # Validate inputs (input_folder may be a directory or a .lst file).
    if not args.input_folder.exists():
        sys.exit(f"Input folder {args.input_folder} does not exist.")
    if not os.path.isfile(args.config_path):
        sys.exit(f"{args.config_path} file does not exist.")

    # Create output directories
    mix_dir = args.output_base / 'mix'
    stems_dir = args.output_base / 'stems'
    os.makedirs(mix_dir, exist_ok=True)
    os.makedirs(stems_dir, exist_ok=True)

    # Initialize the codec model.
    # NOTE(review): the soundstream config path is hard-coded and independent
    # of --resume_path / --config_path — confirm this is intentional.
    config_ss = OmegaConf.load("./final_ckpt/config.yaml")
    soundstream = build_soundstream_model(config_ss)
    parameter_dict = torch.load(args.resume_path)
    soundstream.load_state_dict(parameter_dict['codec_model'])
    soundstream.eval()

    vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)

    # Find matching pairs and skip songs whose mix was already rendered.
    pairs = find_matching_pairs(args.input_folder)
    print(f"Found {len(pairs)} matching pairs")
    pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
    print(f"{len(pairs)} to reconstruct...")

    for instrumental_file, vocal_file, base_name in tqdm(pairs):
        print(f"\nProcessing pair: {base_name}")
        # Create stems directory for this song
        song_stems_dir = stems_dir / base_name
        os.makedirs(song_stems_dir, exist_ok=True)

        try:
            # Process instrumental
            instrumental_output = process_audio(
                instrumental_file,
                song_stems_dir / 'instrumental.mp3',
                args.rescale,
                args,
                inst_decoder,
                soundstream
            )

            # Process vocal
            vocal_output = process_audio(
                vocal_file,
                song_stems_dir / 'vocal.mp3',
                args.rescale,
                args,
                vocal_decoder,
                soundstream
            )
        except IndexError as e:
            # Skip malformed token files but keep processing the rest.
            print(e)
            continue

        # Create and save the mix; stems of different lengths raise RuntimeError.
        try:
            mix_output = instrumental_output + vocal_output
            save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
            print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
        except RuntimeError as e:
            print(e)
            print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")

if __name__ == '__main__':
    main()

    # Example Usage
    # python reconstruct_separately.py --input_folder test_samples --output_base test