sirekist98's picture
Update app.py
73f6e2b verified
raw
history blame
4.55 kB
from spaces import GPU
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
import gradio as gr
import os
# Hugging Face authentication for the private model
from huggingface_hub import login
# Read the access token from the HF_TOKEN environment variable and log in to
# the Hugging Face Hub only when it is present and non-empty.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Config
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_id = "canopylabs/3b-es_it-pretrain-research_release"
lora_model_id = "sirekist98/spanish_conversational_tts"
snac_model_id = "hubertsiuzdak/snac_24khz"

# Load models
# Pass the token read from HF_TOKEN (or None when it is unset/empty) instead
# of a hard-coded ``True``: ``True`` insists on a stored login and makes the
# download fail outright when no token is configured, while ``None`` lets the
# Hub client apply its default behaviour.
_hub_token = hf_token or None

tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_auth_token=_hub_token)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    # Half precision only on CUDA; the CPU path keeps full precision.
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    use_auth_token=_hub_token,
)
# Attach the LoRA adapter to the base model, then move it to the target
# device and switch it to evaluation mode.
model = PeftModel.from_pretrained(base_model, lora_model_id, use_auth_token=_hub_token)
model = model.to(device)
model.eval()

# Audio codec used by decode_snac to turn generated codes into a waveform.
snac_model = SNAC.from_pretrained(snac_model_id).to(device)
# Speakers
# Voice names offered in the UI dropdown; the chosen name is prepended to the
# prompt text as "<speaker>: <text>".
speakers = [
    "Alex",
    "Carmen",
    "Daniel",
    "Diego",
    "Hugo",
    "Lucía",
    "María",
    "Pablo",
    "Sofía",
]
# Helper to decode tokens to audio
def decode_snac(code_list):
    """Decode a flat sequence of SNAC codes into a waveform.

    ``code_list`` is consumed in frames of seven consecutive integers. Within
    a frame, position 0 feeds the first layer, positions 1 and 4 feed the
    second layer, and positions 2, 3, 5 and 6 feed the third layer. The code
    at position ``k`` carries an offset of ``k * 4096`` that is removed
    before decoding. Trailing codes that do not fill a whole frame are
    ignored.

    Args:
        code_list: Sequence of integer codes, frame after frame.

    Returns:
        The output of ``snac_model.decode`` squeezed, moved to the CPU and
        converted to a NumPy array.

    Raises:
        ValueError: If ``code_list`` contains fewer than seven codes.
    """
    frame_size = 7
    band = 4096

    # Floor division on purpose: the earlier ``(len + 1) // 7`` counted one
    # frame too many whenever six codes were left over and then indexed past
    # the end of the list.
    num_frames = len(code_list) // frame_size
    if num_frames == 0:
        raise ValueError("decode_snac needs at least one complete frame of 7 codes")

    layer_1, layer_2, layer_3 = [], [], []
    for start in range(0, num_frames * frame_size, frame_size):
        layer_1.append(code_list[start])
        layer_2.append(code_list[start + 1] - band)
        layer_3.append(code_list[start + 2] - 2 * band)
        layer_3.append(code_list[start + 3] - 3 * band)
        layer_2.append(code_list[start + 4] - 4 * band)
        layer_3.append(code_list[start + 5] - 5 * band)
        layer_3.append(code_list[start + 6] - 6 * band)

    # Build the code tensors on whatever device the codec weights live on.
    device_snac = snac_model.quantizer.quantizers[0].codebook.weight.device
    layers = [
        torch.tensor(layer, dtype=torch.long, device=device_snac).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    with torch.no_grad():
        audio = snac_model.decode(layers).squeeze().cpu().numpy()
    return audio
# Inference
def _extract_audio_codes(generated_ids):
    """Turn raw generated token ids into a flat list of decodable codes.

    Keeps only what follows the last occurrence of token 128257 (everything
    when it never occurs), removes every 128258 token, cuts the remainder to
    whole frames of seven tokens and shifts the ids down by 128266. A frame
    is discarded when any of its codes falls outside the 4096-wide band that
    ``decode_snac`` assigns to that position, so out-of-range values never
    reach the codec.

    Args:
        generated_ids: Tensor of shape ``(1, length)`` as returned by
            ``model.generate`` for a single sequence.

    Returns:
        A flat list of Python ints whose length is a multiple of seven
        (possibly empty).
    """
    marker_id = 128257
    dropped_id = 128258
    code_offset = 128266
    frame_size = 7
    band = 4096

    marker_columns = (generated_ids == marker_id).nonzero(as_tuple=True)[1]
    if marker_columns.numel() > 0:
        generated_ids = generated_ids[:, int(marker_columns[-1]) + 1:]

    # Boolean indexing flattens the (1, length) tensor to one dimension.
    tokens = generated_ids[generated_ids != dropped_id]

    num_frames = tokens.numel() // frame_size
    # Explicit row count so that an empty result reshapes cleanly to (0, 7).
    frames = tokens[: num_frames * frame_size].reshape(num_frames, frame_size) - code_offset

    # Position k of a frame must hold a value in [k * band, (k + 1) * band).
    slot_offsets = torch.arange(frame_size, device=frames.device) * band
    local_codes = frames - slot_offsets
    frame_is_valid = ((local_codes >= 0) & (local_codes < band)).all(dim=1)

    # A single tolist() avoids converting device scalars one at a time.
    return frames[frame_is_valid].flatten().tolist()


@GPU
def tts(prompt, speaker):
    """Generate speech for ``prompt`` in the voice named ``speaker``.

    The text is formatted as ``"<speaker>: <prompt>"``, tokenized, wrapped in
    the start token 128259 and the end tokens 128009 and 128260, left-padded
    with 128263 up to 4260 tokens (padding masked out of attention), and
    passed to ``model.generate`` with sampling. The generated ids are turned
    into codes and decoded with ``decode_snac``.

    Args:
        prompt: Text to speak.
        speaker: Voice name prepended to the text.

    Returns:
        A ``(24000, audio)`` tuple where ``audio`` is the array returned by
        ``decode_snac``.

    Raises:
        gr.Error: If ``prompt`` is blank, or if generation yields no
            decodable audio frame.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Escribe un texto antes de generar el audio.")

    start_id = 128259
    end_ids = [128009, 128260]
    pad_id = 128263
    # NOTE(review): every prompt is left-padded to this fixed length; kept
    # as-is to preserve behaviour — confirm whether the padding is required.
    padded_length = 4260

    full_prompt = f"{speaker}: {prompt}"
    text_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    input_ids = torch.cat(
        [
            torch.tensor([[start_id]], dtype=torch.long, device=device),
            text_ids,
            torch.tensor([end_ids], dtype=torch.long, device=device),
        ],
        dim=1,
    )

    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    padding_len = max(0, padded_length - input_ids.shape[1])
    if padding_len > 0:
        pad = torch.full((1, padding_len), pad_id, dtype=torch.long, device=device)
        input_ids = torch.cat([pad, input_ids], dim=1)
        # Zeros over the padding so the model does not attend to it.
        attention_mask = torch.cat([torch.zeros_like(pad), attention_mask], dim=1)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1200,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=128258,
            use_cache=True,
        )

    codes = _extract_audio_codes(generated_ids)
    if not codes:
        # Nothing decodable came back; report it instead of letting the
        # decoder fail on empty input.
        raise gr.Error("No se pudo generar audio para este texto. Inténtalo de nuevo.")

    audio = decode_snac(codes)
    return (24000, audio)
# Gradio UI
# Build the interface: a row with two columns — text box, speaker dropdown
# and trigger button in the first, the generated audio in the second.
with gr.Blocks() as demo:
    gr.Markdown("# 🗣️ Orpheus Spanish TTS — sin emociones\nSelecciona un *speaker* y escribe el texto.")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Texto",
                placeholder="Escribe aquí el texto a locutar",
            )
            speaker_dropdown = gr.Dropdown(
                choices=speakers,
                value=speakers[0],
                label="Speaker",
            )
            submit_btn = gr.Button("Generar audio")
        with gr.Column():
            audio_output = gr.Audio(label="Audio generado", type="numpy")

    # Clicking the button runs tts on (text, speaker) and shows the result
    # in the audio component.
    submit_btn.click(tts, [text_input, speaker_dropdown], audio_output)

# Enable request queuing, then start the app with errors surfaced in the UI.
demo.queue()
demo.launch(show_error=True)