sirekist98 committed
Commit d1b0389 · verified · 1 Parent(s): 9207326

Update app.py

Files changed (1)
  1. app.py +137 -0
app.py CHANGED
@@ -0,0 +1,137 @@
+ from spaces import GPU
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ from snac import SNAC
+ import gradio as gr
+
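+ # Pipeline (inferred from the code below): the base model autoregressively
+ # generates SNAC audio codes from a text prompt, a LoRA adapter specializes
+ # it for Spanish conversational TTS, and SNAC decodes a 24 kHz waveform.
+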
+ # Config
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ base_model_id = "canopylabs/3b-es_it-pretrain-research_release"
+ lora_model_id = "sirekist98/spanish_conversational_tts"
+ snac_model_id = "hubertsiuzdak/snac_24khz"
+
+ # Load models
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_id,
+     torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
+ )
+ model = PeftModel.from_pretrained(base_model, lora_model_id)
+ model = model.to(device)
+ model.eval()
+
+ snac_model = SNAC.from_pretrained(snac_model_id).to(device)
+
+ # Speakers (no emotions)
+ speakers = [
+     "Alex", "Carmen", "Daniel", "Diego", "Hugo", "Lucía", "María", "Pablo", "Sofía"
+ ]
+
+ # Helper to decode tokens to audio
+
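+ # Note: generation yields a flat stream of 7-token frames: slot 0 feeds SNAC
+ # layer 1, slots 1 and 4 feed layer 2, slots 2, 3, 5 and 6 feed layer 3, and
+ # the k-th slot in each frame carries an offset of k*4096, removed below.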
+ def decode_snac(code_list):
+     layer_1, layer_2, layer_3 = [], [], []
+     for i in range(len(code_list) // 7):
+         layer_1.append(code_list[7*i])
+         layer_2.append(code_list[7*i+1] - 4096)
+         layer_3.append(code_list[7*i+2] - (2*4096))
+         layer_3.append(code_list[7*i+3] - (3*4096))
+         layer_2.append(code_list[7*i+4] - (4*4096))
+         layer_3.append(code_list[7*i+5] - (5*4096))
+         layer_3.append(code_list[7*i+6] - (6*4096))
+
+     # Get the device of the first codebook so the code tensors match it
+     device_snac = snac_model.quantizer.quantizers[0].codebook.weight.device
+
+     layers = [
+         torch.tensor(layer_1).unsqueeze(0).to(device_snac),
+         torch.tensor(layer_2).unsqueeze(0).to(device_snac),
+         torch.tensor(layer_3).unsqueeze(0).to(device_snac),
+     ]
+
+     with torch.no_grad():
+         audio = snac_model.decode(layers).squeeze().cpu().numpy()
+     return audio
+
+
+ # Inference (no emotions)
+ @GPU
+ def tts(prompt, speaker):
+     # Prompt structure: "<SPEAKER>: <text>"
+     full_prompt = f"{speaker}: {prompt}"
+
+     input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
+
+     # Special tokens (same as in the previous version)
+     start_token = torch.tensor([[128259]], dtype=torch.long).to(device)
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.long).to(device)
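+     # (Assumed, per the Orpheus token layout: 128259 opens the text segment,
+     # 128009 is the Llama-3 end-of-turn token, and 128260 closes the segment.)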
+
+     input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+
+     # Fixed padding to 4260 so it matches the training setup
+     padding_len = max(0, 4260 - input_ids.shape[1])
+     if padding_len > 0:
+         pad = torch.full((1, padding_len), 128263, dtype=torch.long).to(device)
+         input_ids = torch.cat([pad, input_ids], dim=1)
+         attention_mask = torch.cat([
+             torch.zeros((1, padding_len), dtype=torch.long),
+             torch.ones((1, input_ids.shape[1] - padding_len), dtype=torch.long)
+         ], dim=1).to(device)
+     else:
+         attention_mask = torch.ones_like(input_ids, dtype=torch.long).to(device)
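+     # The left-padding above mirrors the fixed-length training inputs: pads
+     # sit before the prompt and are masked out, so generation continues
+     # directly after the prompt tokens.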
+
+     with torch.no_grad():
+         generated_ids = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=1200,
+             do_sample=True,
+             temperature=0.6,
+             top_p=0.95,
+             repetition_penalty=1.1,
+             num_return_sequences=1,
+             eos_token_id=128258,
+             use_cache=True,
+         )
+
+     # Post-processing: crop after the last 128257 token and strip 128258
+     token_to_find = 128257
+     token_to_remove = 128258
+     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
+     if len(token_indices[1]) > 0:
+         last_occurrence_idx = token_indices[1][-1].item()
+         cropped = generated_ids[:, last_occurrence_idx+1:]
+     else:
+         cropped = generated_ids
+
+     cleaned = cropped[cropped != token_to_remove]
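+     # 128257 appears to mark the start of the audio-code stream and 128258
+     # the end of speech (it is also eos_token_id above), so only audio codes
+     # remain after this step.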
+
+     # Ensure a multiple of 7 and remove the SNAC token offset
+     trimmed = cleaned[: (len(cleaned) // 7) * 7]
+     trimmed = [int(t) - 128266 for t in trimmed]
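+     # Subtracting 128266 (assumed to be the ID of the first audio token in
+     # the extended vocabulary) maps IDs back to raw SNAC codes in [0, 7*4096).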
+
+     audio = decode_snac(trimmed)
+     return (24000, audio)  # gr.Audio accepts a (sample_rate, numpy_array) tuple
+
+
+ # Gradio UI (simple: text + speaker)
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🗣️ Orpheus Spanish TTS (no emotions)\nPick a *speaker* and type the text.")
+
+     with gr.Row():
+         with gr.Column():
+             text_input = gr.Textbox(label="Text", placeholder="Type the text to narrate here")
+             speaker_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
+             submit_btn = gr.Button("Generate audio")
+         with gr.Column():
+             audio_output = gr.Audio(label="Generated audio")
+
+     submit_btn.click(
+         fn=tts,
+         inputs=[text_input, speaker_dropdown],
+         outputs=audio_output,
+     )
+
+ if __name__ == "__main__":
+     demo.launch()