sirekist98 committed (verified)
Commit 73f6e2b · 1 Parent(s): d1b0389

Update app.py

Files changed (1): app.py (+15 -21)
app.py CHANGED
@@ -4,6 +4,13 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from snac import SNAC
 import gradio as gr
+import os
+
+# Hugging Face authentication for the private model
+from huggingface_hub import login
+hf_token = os.environ.get("HF_TOKEN")
+if hf_token:
+    login(token=hf_token)
 
 # Config
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -12,24 +19,24 @@ lora_model_id = "sirekist98/spanish_conversational_tts"
 snac_model_id = "hubertsiuzdak/snac_24khz"
 
 # Load models
-tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_auth_token=True)
 base_model = AutoModelForCausalLM.from_pretrained(
     base_model_id,
     torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
+    use_auth_token=True
 )
-model = PeftModel.from_pretrained(base_model, lora_model_id)
+model = PeftModel.from_pretrained(base_model, lora_model_id, use_auth_token=True)
 model = model.to(device)
 model.eval()
 
 snac_model = SNAC.from_pretrained(snac_model_id).to(device)
 
-# Speakers (no emotions)
+# Speakers
 speakers = [
     "Alex", "Carmen", "Daniel", "Diego", "Hugo", "Lucía", "María", "Pablo", "Sofía"
 ]
 
 # Helper to decode tokens to audio
-
 def decode_snac(code_list):
     layer_1, layer_2, layer_3 = [], [], []
     for i in range((len(code_list)+1)//7):
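Review note on the auth changes: once `login(token=hf_token)` has run at import time, the hub token is cached process-wide, so the per-call flags are belt-and-braces; also, recent transformers releases deprecate `use_auth_token=` in favor of `token=`. A minimal sketch of the newer spelling, with a placeholder model id standing in for the script's `base_model_id`:

# Sketch only, assuming transformers >= 4.34 (use_auth_token deprecated there);
# the model id below is a placeholder, not the one this Space uses.
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_token = os.environ.get("HF_TOKEN")  # same Space secret as in the commit
base_model_id = "your-org/your-private-base-model"  # hypothetical

tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=hf_token)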
@@ -41,9 +48,7 @@ def decode_snac(code_list):
         layer_3.append(code_list[7*i+5]-(5*4096))
         layer_3.append(code_list[7*i+6]-(6*4096))
 
-    # Get the device of the first codebook
     device_snac = snac_model.quantizer.quantizers[0].codebook.weight.device
-
     layers = [
         torch.tensor(layer_1).unsqueeze(0).to(device_snac),
         torch.tensor(layer_2).unsqueeze(0).to(device_snac),
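For readers of `decode_snac`: each group of 7 generated tokens encodes one SNAC frame spread over three codebook layers (1 coarse + 2 medium + 4 fine codes), and slot n carries an offset of n*4096 so the seven slots occupy disjoint ranges of the LM vocabulary. The first five appends are elided by the diff; a sketch of the full loop, assuming the standard Orpheus-style interleaving rather than anything confirmed by this commit:

# Assumed full frame-unpacking loop (only the +5/+6 appends are visible above).
code_list = [n * 4096 for n in range(7)]  # one toy frame; each slot decodes to code 0
layer_1, layer_2, layer_3 = [], [], []
for i in range((len(code_list) + 1) // 7):
    layer_1.append(code_list[7*i])                # slot 0: coarse code, no offset
    layer_2.append(code_list[7*i + 1] - 4096)     # slot 1: medium code
    layer_3.append(code_list[7*i + 2] - 2*4096)   # slots 2-3: fine codes
    layer_3.append(code_list[7*i + 3] - 3*4096)
    layer_2.append(code_list[7*i + 4] - 4*4096)   # slot 4: medium code
    layer_3.append(code_list[7*i + 5] - 5*4096)   # slots 5-6: fine codes
    layer_3.append(code_list[7*i + 6] - 6*4096)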
@@ -54,22 +59,17 @@ def decode_snac(code_list):
     audio = snac_model.decode(layers).squeeze().cpu().numpy()
     return audio
 
-
-# Inference (no emotions)
+# Inference
 @GPU
 def tts(prompt, speaker):
-    # Prompt structure: "<SPEAKER>: <text>"
     full_prompt = f"{speaker}: {prompt}"
-
     input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
 
-    # Special tokens (same as the previous version)
     start_token = torch.tensor([[128259]], dtype=torch.long).to(device)
     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.long).to(device)
 
     input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
 
-    # Fixed padding to 4260 to match training
     padding_len = max(0, 4260 - input_ids.shape[1])
     if padding_len > 0:
         pad = torch.full((1, padding_len), 128263, dtype=torch.long).to(device)
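The fixed pad to 4260 tokens mirrors the training-time sequence length, with 128263 as the pad id. The diff hides how the padded batch is fed to `generate`; for orientation, a hedged sketch of the usual companion step (building an attention mask that zeroes the pad positions), assuming left padding as in Orpheus-style training:

# Hedged sketch; the real mask construction lives in lines the diff omits.
import torch

pad_id, max_len = 128263, 4260
input_ids = torch.tensor([[128259, 101, 102, 128009, 128260]])  # toy prompt

padding_len = max(0, max_len - input_ids.shape[1])
if padding_len > 0:
    pad = torch.full((1, padding_len), pad_id, dtype=torch.long)
    input_ids = torch.cat([pad, input_ids], dim=1)   # left pad (assumption)
attention_mask = (input_ids != pad_id).long()        # 0 on pads, 1 on real tokens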
@@ -95,7 +95,6 @@ def tts(prompt, speaker):
         use_cache=True,
     )
 
-    # Post-processing: crop from the last 128257 token and strip 128258
     token_to_find = 128257
     token_to_remove = 128258
     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
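The crop itself sits between this hunk and the next; a plausible reconstruction, consistent with the visible `cropped = generated_ids` fallback, keeps only the tokens after the last occurrence of the start-of-audio marker 128257:

# Hedged reconstruction of the elided crop step.
import torch

token_to_find = 128257
generated_ids = torch.tensor([[11, 128257, 5, 6, 128257, 7, 8, 9]])  # toy output

token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
if len(token_indices[1]) > 0:
    last_occurrence = token_indices[1][-1].item()
    cropped = generated_ids[:, last_occurrence + 1:]  # audio codes only
else:
    cropped = generated_ids                           # marker absent: keep everything
# cropped is now tensor([[7, 8, 9]])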
@@ -106,26 +105,22 @@ def tts(prompt, speaker):
         cropped = generated_ids
 
     cleaned = cropped[cropped != token_to_remove]
-
-    # Ensure a multiple of 7 and adjust the SNAC offset
     trimmed = cleaned[: (len(cleaned) // 7) * 7]
     trimmed = [int(t) - 128266 for t in trimmed]
 
     audio = decode_snac(trimmed)
     return (24000, audio)
 
-
-# Gradio UI (simple: text + speaker)
+# Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("# 🗣️ Orpheus Spanish TTS — sin emociones\nSelecciona un *speaker* y escribe el texto.")
-
     with gr.Row():
         with gr.Column():
             text_input = gr.Textbox(label="Texto", placeholder="Escribe aquí el texto a locutar")
             speaker_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
             submit_btn = gr.Button("Generar audio")
         with gr.Column():
-            audio_output = gr.Audio(label="Audio generado")
+            audio_output = gr.Audio(label="Audio generado", type="numpy")
 
     submit_btn.click(
         fn=tts,
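Adding `type="numpy"` pins down the output contract: the callback must return a `(sample_rate, ndarray)` tuple, which is exactly what `tts` does with `return (24000, audio)`, 24 kHz matching the SNAC model. A minimal standalone illustration:

# Minimal sketch of the (sample_rate, samples) contract for type="numpy".
import numpy as np
import gradio as gr

def dummy_tts(text):
    sr = 24000                              # matches the 24 kHz SNAC decoder
    audio = np.zeros(sr, dtype=np.float32)  # one second of silence as a stand-in
    return (sr, audio)

demo = gr.Interface(fn=dummy_tts, inputs="text", outputs=gr.Audio(type="numpy"))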
@@ -133,5 +128,4 @@ with gr.Blocks() as demo:
         outputs=audio_output,
     )
 
-if __name__ == "__main__":
-    demo.launch()
+demo.queue().launch(show_error=True)
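Final note: `demo.queue()` returns the `Blocks` object, so chaining `.launch(show_error=True)` is equivalent to calling them separately; the queue serializes long GPU jobs and `show_error=True` surfaces Python exceptions in the browser instead of a generic failure. Dropping the `__main__` guard is harmless on Spaces, where app.py runs as the top-level script. If the Space needs backpressure, the queue can be bounded; parameter names below assume Gradio 4.x:

# Sketch: explicit queue/launch configuration (Gradio 4.x parameter names assumed).
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("stub UI")

demo.queue(max_size=20)        # reject new requests beyond 20 pending jobs
demo.launch(show_error=True)   # report tracebacks to the client UI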