	Update app.py

app.py CHANGED
@@ -56,36 +56,33 @@ from omegaconf import OmegaConf
 import torchaudio
 from torchaudio.transforms import Resample
 import soundfile as sf
-
 from tqdm import tqdm
 from einops import rearrange
 from codecmanipulator import CodecManipulator
 from mmtokenizer import _MMSentencePieceTokenizer
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
-import glob
-import time
-import copy
-from collections import Counter
 from models.soundstream_hubert_new import SoundStream
 from vocoder import build_codec_model, process_audio
 from post_process_audio import replace_low_freq_with_energy_matched
 
+# Install flash attention
+print("Installing flash-attn...")
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
+# Initialize device
 device = "cuda:0"
 
+# Load models once and reuse
+print("Loading models...")
 model = AutoModelForCausalLM.from_pretrained(
     "m-a-p/YuE-s1-7B-anneal-en-cot",
     torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2",
-).to(device)
-# assistant_model = AutoModelForCausalLM.from_pretrained(
-#     "m-a-p/YuE-s2-1B-general",
-#     torch_dtype=torch.float16,
-#     attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
-# ).to(device)
-# assistant_model = torch.compile(assistant_model)
-# model = torch.compile(model)
-# assistant_model.eval()
-model.eval()
+    attn_implementation="flash_attention_2",
+).to(device).eval()
 
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
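Not part of the commit, but for readers adapting it: a minimal sketch, assuming the standard `transformers` `from_pretrained` API, of loading the same checkpoint with a fallback to PyTorch SDPA attention when `flash_attn` is not importable (the commit instead installs flash-attn at startup).

```python
# Illustrative sketch only; mirrors the load above but degrades gracefully
# when flash-attn is missing instead of assuming the pip install succeeded.
import importlib.util

import torch
from transformers import AutoModelForCausalLM

attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=torch.float16,
    attn_implementation=attn_impl,
).to("cuda:0").eval()
```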
@@ -93,308 +90,130 @@ config_path = './xcodec_mini_infer/decoders/config.yaml'
 vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth'
 inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth'
 
-mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
-codectool = CodecManipulator("xcodec", 0, 1)
-model_config = OmegaConf.load(basic_model_config)
 # Load codec model
+model_config = OmegaConf.load(basic_model_config)
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
-
-codec_model.load_state_dict(parameter_dict['codec_model'])
-# codec_model = torch.compile(codec_model)
+codec_model.load_state_dict(torch.load(resume_path, map_location='cpu')['codec_model'])
 codec_model.eval()
 
 # Preload and compile vocoders
 vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-vocal_decoder.to(device)
-inst_decoder.to(device)
-# vocal_decoder = torch.compile(vocal_decoder)
-# inst_decoder = torch.compile(inst_decoder)
-vocal_decoder.eval()
-inst_decoder.eval()
-
-
-def generate_music(
-        max_new_tokens=5,
-        run_n_segments=2,
-        genre_txt=None,
-        lyrics_txt=None,
-        use_audio_prompt=False,
-        audio_prompt_path="",
-        prompt_start_time=0.0,
-        prompt_end_time=30.0,
-        cuda_idx=0,
-        rescale=False,
-):
-    if use_audio_prompt and not audio_prompt_path:
-        raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
-    cuda_idx = cuda_idx
-    max_new_tokens = max_new_tokens * 100
-
-    with tempfile.TemporaryDirectory() as output_dir:
-        stage1_output_dir = os.path.join(output_dir, f"stage1")
-        os.makedirs(stage1_output_dir, exist_ok=True)
-
-        class BlockTokenRangeProcessor(LogitsProcessor):
-            def __init__(self, start_id, end_id):
-                self.blocked_token_ids = list(range(start_id, end_id))
-
-            def __call__(self, input_ids, scores):
-                scores[:, self.blocked_token_ids] = -float("inf")
-                return scores
-
-        def load_audio_mono(filepath, sampling_rate=16000):
-            audio, sr = torchaudio.load(filepath)
-            # Convert to mono
-            audio = torch.mean(audio, dim=0, keepdim=True)
-            # Resample if needed
-            if sr != sampling_rate:
-                resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-                audio = resampler(audio)
-            return audio
-
-        def split_lyrics(lyrics: str):
-            pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-            segments = re.findall(pattern, lyrics, re.DOTALL)
-            structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-            return structured_lyrics
-
-        # Call the function and print the result
-        stage1_output_set = []
-
-        genres = genre_txt.strip()
-        lyrics = split_lyrics(lyrics_txt + "\n")
-        # intruction
-        full_lyrics = "\n".join(lyrics)
-        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-        prompt_texts += lyrics
-
-        random_id = uuid.uuid4()
-        output_seq = None
-        # Here is suggested decoding config
-        top_p = 0.93
-        temperature = 1.0
-        repetition_penalty = 1.2
-        # special tokens
-        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-        raw_output = None
-
-        # Format text prompt
-        run_n_segments = min(run_n_segments + 1, len(lyrics))
-
-        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-            guidance_scale = 1.5 if i <= 1 else 1.2
-            if i == 0:
-                continue
-            if i == 1:
-                if use_audio_prompt:
-                    audio_prompt = load_audio_mono(audio_prompt_path)
-                    audio_prompt.unsqueeze_(0)
-                    with torch.no_grad():
-                        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                    raw_codes = raw_codes.transpose(0, 1)
-                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                    # Format audio prompt
-                    code_ids = codectool.npy2ids(raw_codes[0])
-                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
-                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                        mmtokenizer.eoa]
-                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                        "[end_of_reference]")
-                    head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
-                else:
-                    head_id = mmtokenizer.tokenize(prompt_texts[0])
-                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-            else:
-                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-
-            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-            # Use window slicing in case output sequence exceeds the context of model
-            max_context = 16384 - max_new_tokens - 1
-            if input_ids.shape[-1] > max_context:
-                print(
-                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-                input_ids = input_ids[:, -(max_context):]
-            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-                output_seq = model.generate(
-                    input_ids=input_ids,
-                    max_new_tokens=max_new_tokens,
-                    min_new_tokens=100,
-                    do_sample=True,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
-                    eos_token_id=mmtokenizer.eoa,
-                    pad_token_id=mmtokenizer.eoa,
-                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                    guidance_scale=guidance_scale,
-                    use_cache=True,
-                    top_k=50,
-                    num_beams=1
-                )
-                if output_seq[0][-1].item() != mmtokenizer.eoa:
-                    tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-                    output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-            if i > 1:
-                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-            else:
-                raw_output = output_seq
-            print(len(raw_output))
-
-        # save raw output and check sanity
-        ids = raw_output[0].cpu().numpy()
-        soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-        eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-        if len(soa_idx) != len(eoa_idx):
-            raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-        vocals = []
-        instrumentals = []
-        range_begin = 1 if use_audio_prompt else 0
-        for i in range(range_begin, len(soa_idx)):
-            codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-            if codec_ids[0] == 32016:
-                codec_ids = codec_ids[1:]
-            codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-            vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-            vocals.append(vocals_ids)
-            instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-            instrumentals.append(instrumentals_ids)
-        vocals = np.concatenate(vocals, axis=1)
-        instrumentals = np.concatenate(instrumentals, axis=1)
-
-        vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
-        inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
-        np.save(vocal_save_path, vocals)
-        np.save(inst_save_path, instrumentals)
-        stage1_output_set.append(vocal_save_path)
-        stage1_output_set.append(inst_save_path)
-
+vocal_decoder.to(device).eval()
+inst_decoder.to(device).eval()
 
-
-
-
-        def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-            folder_path = os.path.dirname(path)
-            if not os.path.exists(folder_path):
-                os.makedirs(folder_path)
-            limit = 0.99
-            max_val = wav.abs().max()
-            wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-            torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
-
-        # reconstruct tracks
-        recons_output_dir = os.path.join(output_dir, "recons")
-        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-        os.makedirs(recons_mix_dir, exist_ok=True)
-        tracks = []
-        for npy in stage1_output_set:
-            codec_result = np.load(npy)
-            decodec_rlt = []
-            with torch.no_grad():
-                decoded_waveform = codec_model.decode(
-                    torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                        device))
-            decoded_waveform = decoded_waveform.cpu().squeeze(0)
-            decodec_rlt.append(torch.as_tensor(decoded_waveform))
-            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-            save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-            tracks.append(save_path)
-            save_audio(decodec_rlt, save_path, 16000)
-        # mix tracks
-        for inst_path in tracks:
-            try:
-                if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                        and 'instrumental' in inst_path:
-                    # find pair
-                    vocal_path = inst_path.replace('instrumental', 'vocal')
-                    if not os.path.exists(vocal_path):
-                        continue
-                    # mix
-                    recons_mix = os.path.join(recons_mix_dir,
-                                              os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                    vocal_stem, sr = sf.read(inst_path)
-                    instrumental_stem, _ = sf.read(vocal_path)
-                    mix_stem = (vocal_stem + instrumental_stem) / 1
-                    sf.write(recons_mix, mix_stem, sr)
-            except Exception as e:
-                print(e)
-
-        # vocoder to upsample audios
-        vocoder_output_dir = os.path.join(output_dir, 'vocoder')
-        vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
-        vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
-        os.makedirs(vocoder_mix_dir, exist_ok=True)
-        os.makedirs(vocoder_stems_dir, exist_ok=True)
-        instrumental_output = None
-        vocal_output = None
-        for npy in stage1_output_set:
-            if 'instrumental' in npy:
-                # Process instrumental
-                instrumental_output = process_audio(
-                    npy,
-                    os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
-                    rescale,
-                    argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-                    inst_decoder,
-                    codec_model
-                )
-            else:
-                # Process vocal
-                vocal_output = process_audio(
-                    npy,
-                    os.path.join(vocoder_stems_dir, 'vocal.mp3'),
-                    rescale,
-                    argparse.Namespace(**locals()),  # Convert local variables to argparse.Namespace
-                    vocal_decoder,
-                    codec_model
-                )
-        # mix tracks
-        try:
-            mix_output = instrumental_output + vocal_output
-            vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
-            save_audio(mix_output, vocoder_mix, 44100, rescale)
-            print(f"Created mix: {vocoder_mix}")
-        except RuntimeError as e:
-            print(e)
-            print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
-
-        # Post process
-        final_output_path = os.path.join(output_dir, os.path.basename(recons_mix))
-        replace_low_freq_with_energy_matched(
-            a_file=recons_mix,  # 16kHz
-            b_file=vocoder_mix,  # 48kHz
-            c_file=final_output_path,
-            cutoff_freq=5500.0
-        )
-        print("All process Done")
-
-        # Load the final audio file and return the numpy array
-        final_audio, sr = torchaudio.load(final_output_path)
-        return (sr, final_audio.squeeze().numpy())
+# Tokenizer and codec tool
+mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+codectool = CodecManipulator("xcodec", 0, 1)
 
+def generate_music(genre_txt, lyrics_txt, max_new_tokens=5, run_n_segments=2, use_audio_prompt=False, audio_prompt_path="", prompt_start_time=0.0, prompt_end_time=30.0, rescale=False):
+    if use_audio_prompt and not audio_prompt_path:
+        raise FileNotFoundError("Please provide an audio prompt filepath when enabling 'use_audio_prompt'!")
+
+    max_new_tokens *= 100
+    top_p = 0.93
+    temperature = 1.0
+    repetition_penalty = 1.2
+
+    # Split lyrics into segments
+    def split_lyrics(lyrics):
+        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+        segments = re.findall(pattern, lyrics, re.DOTALL)
+        return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+
+    lyrics = split_lyrics(lyrics_txt + "\n")
+    full_lyrics = "\n".join(lyrics)
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genre_txt.strip()}\n{full_lyrics}"] + lyrics
+
+    raw_output = None
+    stage1_output_set = []
+
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+        guidance_scale = 1.5 if i <= 1 else 1.2
+
+        if i == 0:
+            continue
+
+        if i == 1 and use_audio_prompt:
+            audio_prompt = load_audio_mono(audio_prompt_path)
+            audio_prompt = audio_prompt.unsqueeze(0).to(device)
+            raw_codes = codec_model.encode(audio_prompt, target_bw=0.5).transpose(0, 1).cpu().numpy().astype(np.int16)
+            audio_prompt_codec = codectool.npy2ids(raw_codes[0])[int(prompt_start_time * 50): int(prompt_end_time * 50)]
+            audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+            sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+            head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+        else:
+            head_id = mmtokenizer.tokenize(prompt_texts[0])
+
+        prompt_ids = head_id + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+
+        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+
+        max_context = 16384 - max_new_tokens - 1
+        if input_ids.shape[-1] > max_context:
+            input_ids = input_ids[:, -(max_context):]
+
+        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=100,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([
+                    BlockTokenRangeProcessor(0, 32002),
+                    BlockTokenRangeProcessor(32016, 32016)
+                ]),
+                guidance_scale=guidance_scale,
+                use_cache=True,
+                top_k=50,
+                num_beams=1
+            )
+
+        if output_seq[0][-1].item() != mmtokenizer.eoa:
+            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(device)
+            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+
+        raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1) if i > 1 else output_seq
+
+    # Process and save outputs
+    ids = raw_output[0].cpu().numpy()
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+
+    vocals, instrumentals = [], []
+    for i in range(len(soa_idx)):
+        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+        if codec_ids[0] == 32016:
+            codec_ids = codec_ids[1:]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        vocals.append(codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0]))
+        instrumentals.append(codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1]))
+
+    vocals = np.concatenate(vocals, axis=1)
+    instrumentals = np.concatenate(instrumentals, axis=1)
+
+    # Decode and mix audio
+    decoded_vocals = codec_model.decode(torch.as_tensor(vocals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)).cpu().squeeze(0)
+    decoded_instrumentals = codec_model.decode(torch.as_tensor(instrumentals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)).cpu().squeeze(0)
+
+    mixed_audio = (decoded_vocals + decoded_instrumentals) / 2
+    return (16000, mixed_audio.numpy())
 
 @spaces.GPU(duration=120)
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
-    # Execute the command
     try:
-        audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
-                              cuda_idx=0, max_new_tokens=max_new_tokens)
-        return audio_data
+        return generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, max_new_tokens=max_new_tokens)
     except Exception as e:
-        gr.Warning("An Error
+        gr.Warning("An Error Occurred: " + str(e))
         return None
-    finally:
-        print("Temporary files deleted.")
-
 
-# Gradio
+# Gradio Interface
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
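The new `generate_music` above still calls `BlockTokenRangeProcessor` and `load_audio_mono`, whose definitions previously lived inside the removed function body. For reference, a self-contained sketch of those helpers, reproduced from the deleted code:

```python
# Helpers referenced by the new generate_music(); bodies follow the
# definitions that the old, removed function body contained.
import torch
import torchaudio
from torchaudio.transforms import Resample
from transformers import LogitsProcessor


class BlockTokenRangeProcessor(LogitsProcessor):
    """Blocks a contiguous range of token ids from ever being sampled."""

    def __init__(self, start_id, end_id):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores


def load_audio_mono(filepath, sampling_rate=16000):
    """Load an audio file, downmix to mono, and resample to the target rate."""
    audio, sr = torchaudio.load(filepath)
    audio = torch.mean(audio, dim=0, keepdim=True)  # convert to mono
    if sr != sampling_rate:
        audio = Resample(orig_freq=sr, new_freq=sampling_rate)(audio)
    return audio
```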
@@ -493,4 +312,5 @@ Living out my dreams with this mic and a deal
         inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
         outputs=[music_out]
     )
+
 demo.queue().launch(show_error=True)
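A hypothetical local smoke test of the updated `infer` entry point, bypassing the Gradio UI; the genre tags and lyric line below are illustrative only (the lyric comes from the Space's example prompt) and must follow the `[section]` format that `split_lyrics` expects:

```python
# Illustrative call only; assumes app.py has been imported as a module and a
# CUDA device is available for the @spaces.GPU-decorated function.
genre = "uplifting pop female vocal"  # assumed example tags
lyrics = "[verse]\nLiving out my dreams with this mic and a deal\n"

result = infer(genre, lyrics, num_segments=2, max_new_tokens=10)
if result is not None:
    sr, audio = result  # generate_music() returns (16000, numpy waveform)
    print(f"Generated {audio.shape[-1] / sr:.1f}s of audio at {sr} Hz")
```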
 
			
