from spaces import GPU
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
import gradio as gr
import os

# Hugging Face authentication for the private model
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
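# Note: HF_TOKEN is assumed to be set as a Space secret; without it, the
# private base and LoRA checkpoints below cannot be downloaded.
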
# Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_id = "canopylabs/3b-es_it-pretrain-research_release"
lora_model_id = "sirekist98/spanish_conversational_tts"
snac_model_id = "hubertsiuzdak/snac_24khz"

# Load models (token= replaces the deprecated use_auth_token flag)
tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    token=hf_token,
)
model = PeftModel.from_pretrained(base_model, lora_model_id, token=hf_token)
model = model.to(device)
model.eval()
snac_model = SNAC.from_pretrained(snac_model_id).to(device)
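# Pipeline overview: the LoRA-tuned language model generates a stream of
# SNAC audio codes as tokens; the SNAC vocoder then decodes those codes
# into a 24 kHz waveform.
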
# Speakers
speakers = [
    "Alex", "Carmen", "Daniel", "Diego", "Hugo", "Lucía", "María", "Pablo", "Sofía",
]

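# Each group of 7 generated codes is one SNAC frame, interleaving the three
# codebook levels: 1 coarse code (layer 1), 2 medium codes (layer 2) and
# 4 fine codes (layer 3). Every slot in the frame has its own 4096-entry
# codebook range, hence the k * 4096 offsets subtracted below.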
# Helper to decode tokens to audio
def decode_snac(code_list):
    layer_1, layer_2, layer_3 = [], [], []
    for i in range(len(code_list) // 7):  # one iteration per 7-code frame
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - 2 * 4096)
        layer_3.append(code_list[7 * i + 3] - 3 * 4096)
        layer_2.append(code_list[7 * i + 4] - 4 * 4096)
        layer_3.append(code_list[7 * i + 5] - 5 * 4096)
        layer_3.append(code_list[7 * i + 6] - 6 * 4096)
    # Decode on whatever device the SNAC codebooks live on
    device_snac = snac_model.quantizer.quantizers[0].codebook.weight.device
    layers = [
        torch.tensor(layer_1).unsqueeze(0).to(device_snac),
        torch.tensor(layer_2).unsqueeze(0).to(device_snac),
        torch.tensor(layer_3).unsqueeze(0).to(device_snac),
    ]
    with torch.no_grad():
        audio = snac_model.decode(layers).squeeze().cpu().numpy()
    return audio

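# The prompt is wrapped in special marker tokens from the extended
# vocabulary (an Orpheus-style scheme; IDs taken verbatim from this code):
# 128259 opens the text turn, 128009/128260 close it, 128257 marks the
# start of the audio stream, 128258 ends it, 128263 is padding, and
# generated audio tokens carry a 128266 offset.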
# Inference
@GPU  # ZeroGPU: allocate a GPU for the duration of each call
def tts(prompt, speaker):
    full_prompt = f"{speaker}: {prompt}"
    input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    start_token = torch.tensor([[128259]], dtype=torch.long).to(device)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.long).to(device)
    input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    # Left-pad to a fixed 4260-token context and mask the padding out
    padding_len = max(0, 4260 - input_ids.shape[1])
    if padding_len > 0:
        pad = torch.full((1, padding_len), 128263, dtype=torch.long).to(device)
        input_ids = torch.cat([pad, input_ids], dim=1)
        attention_mask = torch.cat([
            torch.zeros((1, padding_len), dtype=torch.long),
            torch.ones((1, input_ids.shape[1] - padding_len), dtype=torch.long),
        ], dim=1).to(device)
    else:
        attention_mask = torch.ones_like(input_ids, dtype=torch.long).to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1200,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=128258,  # end-of-audio marker stops generation
            use_cache=True,
        )
    # Keep only tokens after the last start-of-audio marker, drop the
    # end-of-audio marker, trim to whole 7-code frames, and remove the
    # audio-token offset before SNAC decoding
    token_to_find = 128257
    token_to_remove = 128258
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped = generated_ids
    cleaned = cropped[cropped != token_to_remove]
    trimmed = cleaned[: (len(cleaned) // 7) * 7]
    trimmed = [int(t) - 128266 for t in trimmed]
    audio = decode_snac(trimmed)
    return (24000, audio)

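# gr.Audio with type="numpy" expects a (sample_rate, np.ndarray) tuple;
# the SNAC 24 kHz model decodes at 24000 Hz, hence the fixed rate above.
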
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🗣️ Orpheus Spanish TTS — no emotions\nSelect a *speaker* and enter the text.")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Text", placeholder="Enter the text to be spoken")
            speaker_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
            submit_btn = gr.Button("Generate audio")
        with gr.Column():
            audio_output = gr.Audio(label="Generated audio", type="numpy")
    submit_btn.click(
        fn=tts,
        inputs=[text_input, speaker_dropdown],
        outputs=audio_output,
    )

demo.queue().launch(show_error=True)