Curinha committed on
Commit
451ea25
1 Parent(s): f26b7a5

Refactor app.py to remove GPU quota management and simplify sound generation endpoints

Browse files
Files changed (1) hide show
  1. app.py +16 -179
app.py CHANGED
@@ -1,14 +1,8 @@
1
- import asyncio
2
- from datetime import datetime
3
  import os
4
- import random
5
- import time
6
- from typing import Dict, List
7
- import torch
8
  import uvicorn
9
 
10
  from sound_generator import generate_sound, generate_music
11
- from fastapi import Depends, FastAPI, HTTPException, Request
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from fastapi.templating import Jinja2Templates
14
  from fastapi.responses import FileResponse, HTMLResponse
@@ -24,9 +18,10 @@ app = FastAPI(
24
  )
25
 
26
 
27
- # Configuración de templates
28
  templates = Jinja2Templates(directory="templates")
29
 
 
30
  # Configuración de CORS
31
  app.add_middleware(
32
  CORSMiddleware,
@@ -36,146 +31,18 @@ app.add_middleware(
36
  allow_headers=["*"],
37
  )
38
 
 
 
39
  class AudioRequest(BaseModel):
40
  prompt: str
41
 
42
- class GPUQuotaConfig:
43
- MAX_REQUEST_DURATION = 20 # segundos máximos por solicitud
44
- DAILY_QUOTA = 300 # 5 minutos en total (300 segundos)
45
-
46
- class QuotaTracker:
47
- def __init__(self):
48
- self.users_quota: Dict[str, int] = {}
49
- self.user_reset_times: Dict[str, datetime] = {}
50
- self.current_user_index = 0
51
- self.registered_users: List[str] = []
52
-
53
- def register_user(self, user_id: str):
54
- if user_id not in self.registered_users:
55
- self.registered_users.append(user_id)
56
- self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
57
- self.user_reset_times[user_id] = datetime.now() + datetime.timedelta(days=1)
58
-
59
- def get_next_available_user(self):
60
- # Verificar resets
61
- for user_id in list(self.user_reset_times.keys()):
62
- if datetime.now() > self.user_reset_times[user_id]:
63
- self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
64
- self.user_reset_times[user_id] = datetime.now() + datetime.timedelta(days=1)
65
-
66
- # Encontrar usuario con cuota
67
- attempts = 0
68
- while attempts < len(self.registered_users):
69
- self.current_user_index = (self.current_user_index + 1) % max(1, len(self.registered_users))
70
- current_user = self.registered_users[self.current_user_index]
71
- if self.users_quota.get(current_user, 0) >= GPUQuotaConfig.MAX_REQUEST_DURATION:
72
- return current_user
73
- attempts += 1
74
-
75
- return None
76
-
77
- def consume_quota(self, user_id: str, seconds: int):
78
- if user_id in self.users_quota:
79
- self.users_quota[user_id] = max(0, self.users_quota[user_id] - seconds)
80
- return True
81
- return False
82
-
83
- def get_remaining_quota(self, user_id: str):
84
- if user_id in self.users_quota:
85
- # Verificar si se debe resetear
86
- if datetime.now() > self.user_reset_times.get(user_id, datetime.max):
87
- self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
88
- self.user_reset_times[user_id] = datetime.now() + datetime.timedelta(days=1)
89
- return self.users_quota[user_id]
90
- return 0
91
-
92
- def get_system_status(self):
93
- return {
94
- "registered_users": len(self.registered_users),
95
- "users_with_quota": sum(1 for q in self.users_quota.values() if q >= GPUQuotaConfig.MAX_REQUEST_DURATION),
96
- "total_available_seconds": sum(self.users_quota.values())
97
- }
98
-
99
- # Inicializar sistema
100
- quota_tracker = QuotaTracker()
101
-
102
- # Registrar usuarios virtuales
103
- for i in range(5):
104
- quota_tracker.register_user(f"virtual_user_{i}")
105
-
106
- # Semáforo para controlar acceso a GPU - solo una tarea a la vez
107
- gpu_semaphore = asyncio.Semaphore(1)
108
-
109
- # Middleware para asignar user_id
110
- @app.middleware("http")
111
- async def assign_user_id(request: Request, call_next):
112
- if "user-id" not in request.headers:
113
- request.state.user_id = f"anonymous_{random.randint(1000, 9999)}"
114
- quota_tracker.register_user(request.state.user_id)
115
- else:
116
- request.state.user_id = request.headers["user-id"]
117
- quota_tracker.register_user(request.state.user_id)
118
-
119
- response = await call_next(request)
120
- return response
121
-
122
- async def get_user_id(request: Request):
123
- return request.state.user_id
124
-
125
- # Función para manejar la generación con control de GPU
126
- async def process_with_gpu(generation_func, prompt, process_id):
127
- start_time = time.time()
128
- print(f"[{process_id}] Iniciando procesamiento GPU")
129
-
130
- # Buscar usuario con cuota disponible
131
- user_id = quota_tracker.get_next_available_user()
132
- if not user_id:
133
- raise HTTPException(status_code=429, detail="No hay cuota GPU disponible en el sistema")
134
-
135
- quota_available = quota_tracker.get_remaining_quota(user_id)
136
- print(f"[{process_id}] Usando cuota de usuario {user_id}: {quota_available}s disponibles")
137
-
138
- # Verificar si hay suficiente cuota
139
- if quota_available < GPUQuotaConfig.MAX_REQUEST_DURATION:
140
- raise HTTPException(status_code=429, detail=f"Cuota GPU insuficiente ({quota_available}s disponibles)")
141
-
142
- # Verificar que los modelos usen GPU si está disponible
143
- use_gpu = torch.cuda.is_available()
144
- device = 'cuda' if use_gpu else 'cpu'
145
- print(f"[{process_id}] Usando dispositivo: {device}")
146
-
147
- try:
148
- # Llamar a la función de generación con límite de tiempo
149
- audio_file_path = await asyncio.to_thread(
150
- generation_func, prompt, device, user_id
151
- )
152
-
153
- # Liberar memoria GPU si se utilizó
154
- if use_gpu:
155
- torch.cuda.empty_cache()
156
-
157
- # Calcular tiempo real usado
158
- elapsed_time = min(GPUQuotaConfig.MAX_REQUEST_DURATION, int(time.time() - start_time))
159
-
160
- # Consumir cuota
161
- quota_tracker.consume_quota(user_id, elapsed_time)
162
- print(f"[{process_id}] Procesamiento completado en {elapsed_time}s, cuota restante: {quota_tracker.get_remaining_quota(user_id)}s")
163
-
164
- return audio_file_path
165
-
166
- except Exception as e:
167
- # Asegurar que liberamos memoria en caso de error
168
- if use_gpu:
169
- torch.cuda.empty_cache()
170
- print(f"[{process_id}] Error: {str(e)}")
171
- raise e
172
-
173
 
174
- # Home page with API information
175
  @app.get("/", response_class=HTMLResponse)
176
  def home(request: Request):
 
177
  return templates.TemplateResponse("home.html", {"request": request})
178
 
 
179
  # Prueba para verificar si la API funciona - la dejamos por ahora para debugging
180
  @app.get("/health")
181
  def health_check():
@@ -184,15 +51,10 @@ def health_check():
184
 
185
 
186
  @app.post("/generate-sound/")
187
- async def generate_sound_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
188
  try:
189
- process_id = f"sound_{random.randint(1000, 9999)}"
190
-
191
- # Usar semáforo para asegurar acceso exclusivo a GPU
192
- async with gpu_semaphore:
193
- audio_file_path = await process_with_gpu(
194
- generate_sound, request.prompt, process_id
195
- )
196
 
197
  # Verifica si el archivo se ha generado correctamente
198
  if not os.path.exists(audio_file_path):
@@ -205,55 +67,30 @@ async def generate_sound_endpoint(request: AudioRequest, user_id: str = Depends(
205
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
206
  )
207
 
208
- except HTTPException as e:
209
- # Reenviar excepciones HTTP
210
- raise e
211
  except Exception as e:
212
  raise HTTPException(status_code=500, detail=str(e))
213
 
 
214
  @app.post("/generate-music/")
215
- async def generate_music_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
216
  try:
217
- process_id = f"music_{random.randint(1000, 9999)}"
218
-
219
- # Usar semáforo para asegurar acceso exclusivo a GPU
220
- async with gpu_semaphore:
221
- audio_file_path = await process_with_gpu(
222
- generate_music, request.prompt, process_id
223
- )
224
 
225
- # Verifica si el archivo se ha generado correctamente
226
  if not os.path.exists(audio_file_path):
227
  raise HTTPException(
228
  status_code=404, detail="Archivo de audio no encontrado."
229
  )
230
 
231
- # Regresar el archivo generado como una respuesta de descarga
232
  return FileResponse(
233
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
234
  )
235
 
236
- except HTTPException as e:
237
- # Reenviar excepciones HTTP
238
- raise e
239
  except Exception as e:
240
  raise HTTPException(status_code=500, detail=str(e))
241
 
242
- @app.get("/quota-status")
243
- async def quota_status_endpoint(user_id: str = Depends(get_user_id)):
244
- user_quota = quota_tracker.get_remaining_quota(user_id)
245
- system_status = quota_tracker.get_system_status()
246
-
247
- return {
248
- "user_id": user_id,
249
- "quota_remaining": user_quota,
250
- "reset_time": quota_tracker.user_reset_times.get(user_id, None),
251
- "system_status": system_status,
252
- "gpu_available": torch.cuda.is_available(),
253
- "device_info": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
254
- }
255
-
256
-
257
 
258
  if __name__ == "__main__":
259
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
1
  import os
 
 
 
 
2
  import uvicorn
3
 
4
  from sound_generator import generate_sound, generate_music
5
+ from fastapi import FastAPI, HTTPException, Request
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.templating import Jinja2Templates
8
  from fastapi.responses import FileResponse, HTMLResponse
 
18
  )
19
 
20
 
21
+ # Cargar las plantillas desde la carpeta "templates"
22
  templates = Jinja2Templates(directory="templates")
23
 
24
+
25
  # Configuración de CORS
26
  app.add_middleware(
27
  CORSMiddleware,
 
31
  allow_headers=["*"],
32
  )
33
 
34
+
35
+ # Define a Pydantic model to handle the input prompt
36
  class AudioRequest(BaseModel):
37
  prompt: str
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
 
40
  @app.get("/", response_class=HTMLResponse)
41
  def home(request: Request):
42
+ """Página de inicio con información de la API"""
43
  return templates.TemplateResponse("home.html", {"request": request})
44
 
45
+
46
  # Prueba para verificar si la API funciona - la dejamos por ahora para debugging
47
  @app.get("/health")
48
  def health_check():
 
51
 
52
 
53
  @app.post("/generate-sound/")
54
+ async def generate_sound_endpoint(request: AudioRequest):
55
  try:
56
+ # Llamada a la función para generar el sonido
57
+ audio_file_path = generate_sound(request.prompt)
 
 
 
 
 
58
 
59
  # Verifica si el archivo se ha generado correctamente
60
  if not os.path.exists(audio_file_path):
 
67
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
68
  )
69
 
 
 
 
70
  except Exception as e:
71
  raise HTTPException(status_code=500, detail=str(e))
72
 
73
+
74
  @app.post("/generate-music/")
75
+ async def generate_music_endpoint(request: AudioRequest):
76
  try:
77
+ # Call the synchronous generate_music function
78
+ audio_file_path = generate_music(request.prompt)
 
 
 
 
 
79
 
80
+ # Verifies if the file has been generated correctly
81
  if not os.path.exists(audio_file_path):
82
  raise HTTPException(
83
  status_code=404, detail="Archivo de audio no encontrado."
84
  )
85
 
86
+ # Return the generated file as a download response
87
  return FileResponse(
88
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
89
  )
90
 
 
 
 
91
  except Exception as e:
92
  raise HTTPException(status_code=500, detail=str(e))
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  if __name__ == "__main__":
96
  uvicorn.run(app, host="0.0.0.0", port=7860)