Update app.py

app.py (CHANGED)
[The diff viewer's rendering of the removed (old) lines is garbled in this capture: most are cut off mid-line. In the reconstruction below, removed lines whose tails were lost are marked with a trailing …, and removed regions with no recoverable content are omitted. The surviving fragments show that the previous version: loaded the LLM through a loader configured with use_gradient_checkpointing = True, random_state = 3407 and max_seq_length = 2048; passed weights_only=True to whisper.load_model as a "security fix"; used bare @spaces.GPU() decorators and @lru_cache decorators with a different maxsize; tokenized prompts directly with model_manager.tokenizer and generated with max_new_tokens = size + 100, temperature = 0.7, do_sample = True and pad_token_id = model_manager.tokenizer.eos_token_id; truncated URL excerpts at 1000 characters ("... [content continues]"); gathered social inputs with "for i in range(0, len(social_urls), 3)" via a process_social_content() helper; and built the prompt inline, with "Instructions: {knowledge_base["instructions"]}", "Use these ..." and bullet rules such as "Write a 15-word hook that complements the title" and "Write the body". The capture also ends before the new-side rendering of the file's tail, so only the truncated old-side view survives for create_demo() and the __main__ launcher: a gr.Blocks(theme=gr.themes.Soft()) UI titled "📰 NewsIA - AI News Generator" with the tagline "Create professional news articles from multiple sources.", instruction and facts textboxes, an "Approximate Length (words)" slider (minimum=100, value=250, step=50), a tone dropdown defaulting to "neutral", five document uploads (pdf/docx/xlsx/csv, file_count="single"), five audio/video sources each with Name and Position fields, five URL fields, three social media entries (URL, Account, Context), and "📄 Generated News" / "🎙️ Transcriptions" output textboxes with show_copy_button=True, wired to generate_news via an inputs_list beginning [instructions, facts, size, tone].]
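The heart of the change is visible in the second hunk below: model loading now goes through stock transformers APIs (AutoTokenizer, AutoModelForCausalLM, pipeline) rather than the previous loader. As a condensed, self-contained sketch of that pattern (model name, dtype and pipeline wiring as in the diff; the prompt string is illustrative only):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Tokenizer first; TinyLlama ships without a pad token, so reuse EOS.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# fp16 weights with automatic device placement, as in initialize_llm() below.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# The app keeps a single text-generation pipeline around and reuses it.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
result = generator(
    "Write a one-sentence news hook about renewable energy:",  # illustrative prompt
    max_new_tokens=40,
    do_sample=True,
    temperature=0.7,
)
print(result[0]["generated_text"])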
| 10 |
import whisper
|
| 11 |
import subprocess
|
| 12 |
from pydub import AudioSegment
|
| 13 |
+
import fitz # PyMuPDF
|
| 14 |
import docx
|
| 15 |
import yt_dlp
|
| 16 |
from functools import lru_cache
|
| 17 |
import gc
|
| 18 |
import time
|
| 19 |
from huggingface_hub import login
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 21 |
+
import traceback # For detailed error logging
|
| 22 |
|
| 23 |
# Configure logging
|
| 24 |
logging.basicConfig(
|
|
|
|
| 30 |
# Login to Hugging Face Hub if token is available
|
| 31 |
HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
|
| 32 |
if HUGGINGFACE_TOKEN:
|
| 33 |
+
try:
|
| 34 |
+
login(token=HUGGINGFACE_TOKEN)
|
| 35 |
+
logger.info("Successfully logged in to Hugging Face Hub.")
|
| 36 |
+
except Exception as e:
|
| 37 |
+
logger.error(f"Failed to login to Hugging Face Hub: {e}")
|
| 38 |
|
| 39 |
class ModelManager:
|
| 40 |
_instance = None
|
| 41 |
+
|
| 42 |
def __new__(cls):
|
| 43 |
if cls._instance is None:
|
| 44 |
cls._instance = super(ModelManager, cls).__new__(cls)
|
| 45 |
cls._instance._initialized = False
|
| 46 |
return cls._instance
|
| 47 |
+
|
| 48 |
def __init__(self):
|
| 49 |
if not self._initialized:
|
| 50 |
self.tokenizer = None
|
| 51 |
self.model = None
|
| 52 |
+
self.text_pipeline = None # Renamed for clarity
|
| 53 |
self.whisper_model = None
|
| 54 |
self._initialized = True
|
| 55 |
self.last_used = time.time()
|
| 56 |
+
self.llm_loading = False
|
| 57 |
+
self.whisper_loading = False
|
| 58 |
+
|
| 59 |
+
@spaces.GPU(duration=120) # Increased duration for potentially long loads
|
| 60 |
def initialize_llm(self):
|
| 61 |
+
"""Initialize LLM model with standard transformers"""
|
| 62 |
+
if self.llm_loading:
|
| 63 |
+
logger.info("LLM initialization already in progress.")
|
| 64 |
+
return True # Assume it will succeed or fail elsewhere
|
| 65 |
+
if self.tokenizer and self.model and self.text_pipeline:
|
| 66 |
+
logger.info("LLM already initialized.")
|
| 67 |
+
self.last_used = time.time()
|
| 68 |
+
return True
|
| 69 |
+
|
| 70 |
+
self.llm_loading = True
|
| 71 |
try:
|
| 72 |
+
# Use small model for ZeroGPU compatibility
|
| 73 |
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 74 |
+
|
| 75 |
+
logger.info("Loading LLM tokenizer...")
|
| 76 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 77 |
+
MODEL_NAME,
|
| 78 |
+
token=HUGGINGFACE_TOKEN,
|
| 79 |
+
use_fast=True
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if self.tokenizer.pad_token is None:
|
| 83 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 84 |
+
|
| 85 |
+
# Basic memory settings for ZeroGPU
|
| 86 |
+
logger.info("Loading LLM model...")
|
| 87 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 88 |
+
MODEL_NAME,
|
| 89 |
+
token=HUGGINGFACE_TOKEN,
|
| 90 |
+
device_map="auto",
|
| 91 |
+
torch_dtype=torch.float16,
|
| 92 |
+
low_cpu_mem_usage=True,
|
| 93 |
+
# Optimizations for ZeroGPU
|
| 94 |
+
# max_memory={0: "4GB"}, # Removed for better auto handling initially
|
| 95 |
+
offload_folder="offload",
|
| 96 |
+
offload_state_dict=True
|
| 97 |
)
|
| 98 |
+
|
| 99 |
+
# Create text generation pipeline
|
| 100 |
+
logger.info("Creating LLM text generation pipeline...")
|
| 101 |
+
self.text_pipeline = pipeline(
|
| 102 |
+
"text-generation",
|
| 103 |
+
model=self.model,
|
| 104 |
+
tokenizer=self.tokenizer,
|
| 105 |
+
torch_dtype=torch.float16,
|
| 106 |
+
device_map="auto",
|
| 107 |
+
max_length=1024 # Default max length
|
|
|
|
|
|
|
|
|
|
| 108 |
)
|
| 109 |
+
|
| 110 |
+
logger.info("LLM initialized successfully")
|
| 111 |
self.last_used = time.time()
|
| 112 |
+
self.llm_loading = False
|
| 113 |
return True
|
| 114 |
+
|
| 115 |
except Exception as e:
|
| 116 |
logger.error(f"Error initializing LLM: {str(e)}")
|
| 117 |
+
logger.error(traceback.format_exc()) # Log full traceback
|
| 118 |
+
# Reset partially loaded components
|
| 119 |
+
self.tokenizer = None
|
| 120 |
+
self.model = None
|
| 121 |
+
self.text_pipeline = None
|
| 122 |
+
if torch.cuda.is_available():
|
| 123 |
+
torch.cuda.empty_cache()
|
| 124 |
+
gc.collect()
|
| 125 |
+
self.llm_loading = False
|
| 126 |
+
raise # Re-raise the exception to signal failure
|
| 127 |
|
| 128 |
+
@spaces.GPU(duration=120) # Increased duration
|
| 129 |
def initialize_whisper(self):
|
| 130 |
+
"""Initialize Whisper model for audio transcription"""
|
| 131 |
+
if self.whisper_loading:
|
| 132 |
+
logger.info("Whisper initialization already in progress.")
|
| 133 |
+
return True
|
| 134 |
+
if self.whisper_model:
|
| 135 |
+
logger.info("Whisper already initialized.")
|
| 136 |
+
self.last_used = time.time()
|
| 137 |
+
return True
|
| 138 |
+
|
| 139 |
+
self.whisper_loading = True
|
| 140 |
try:
|
| 141 |
logger.info("Loading Whisper model...")
|
| 142 |
+
# Using tiny model for efficiency but can be changed based on needs
|
| 143 |
+
# Specify weights_only=True to address the FutureWarning
|
| 144 |
+
# Note: Whisper's load_model might not directly support weights_only yet.
|
| 145 |
+
# If it errors, remove the weights_only=True. The warning is mainly informative.
|
| 146 |
+
# Let's attempt without weights_only first as whisper might handle it internally
|
| 147 |
self.whisper_model = whisper.load_model(
|
| 148 |
+
"tiny", # Consider "base" for better accuracy if "tiny" struggles
|
| 149 |
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 150 |
+
download_root="/tmp/whisper" # Use persistent storage if available/needed
|
|
|
|
| 151 |
)
|
| 152 |
logger.info("Whisper model initialized successfully")
|
| 153 |
self.last_used = time.time()
|
| 154 |
+
self.whisper_loading = False
|
| 155 |
return True
|
| 156 |
except Exception as e:
|
| 157 |
logger.error(f"Error initializing Whisper: {str(e)}")
|
| 158 |
+
logger.error(traceback.format_exc())
|
| 159 |
+
self.whisper_model = None
|
| 160 |
+
if torch.cuda.is_available():
|
| 161 |
+
torch.cuda.empty_cache()
|
| 162 |
+
gc.collect()
|
| 163 |
+
self.whisper_loading = False
|
| 164 |
raise
|
| 165 |
|
| 166 |
def check_llm_initialized(self):
|
| 167 |
"""Check if LLM is initialized and initialize if needed"""
|
| 168 |
+
if self.tokenizer is None or self.model is None or self.text_pipeline is None:
|
| 169 |
logger.info("LLM not initialized, initializing...")
|
| 170 |
+
if not self.llm_loading: # Prevent re-entry if already loading
|
| 171 |
+
self.initialize_llm()
|
| 172 |
+
else:
|
| 173 |
+
logger.info("LLM initialization is already in progress by another request.")
|
| 174 |
+
# Optional: Wait a bit for the other process to finish
|
| 175 |
+
time.sleep(5)
|
| 176 |
+
if self.tokenizer is None or self.model is None or self.text_pipeline is None:
|
| 177 |
+
raise RuntimeError("LLM initialization timed out or failed.")
|
| 178 |
self.last_used = time.time()
|
| 179 |
+
|
| 180 |
def check_whisper_initialized(self):
|
| 181 |
"""Check if Whisper model is initialized and initialize if needed"""
|
| 182 |
if self.whisper_model is None:
|
| 183 |
logger.info("Whisper model not initialized, initializing...")
|
| 184 |
+
if not self.whisper_loading: # Prevent re-entry
|
| 185 |
+
self.initialize_whisper()
|
| 186 |
+
else:
|
| 187 |
+
logger.info("Whisper initialization is already in progress by another request.")
|
| 188 |
+
time.sleep(5)
|
| 189 |
+
if self.whisper_model is None:
|
| 190 |
+
raise RuntimeError("Whisper initialization timed out or failed.")
|
| 191 |
self.last_used = time.time()
|
| 192 |
+
|
| 193 |
def reset_models(self, force=False):
|
| 194 |
"""Reset models to free memory if they haven't been used recently"""
|
| 195 |
current_time = time.time()
|
| 196 |
+
# Only reset if forced or models haven't been used for 10 minutes (600 seconds)
|
| 197 |
+
if force or (current_time - self.last_used > 600):
|
| 198 |
try:
|
| 199 |
logger.info("Resetting models to free memory...")
|
| 200 |
+
|
| 201 |
+
# Check and delete attributes safely
|
| 202 |
+
if hasattr(self, 'model') and self.model is not None:
|
| 203 |
del self.model
|
| 204 |
+
self.model = None
|
| 205 |
+
logger.info("LLM model deleted.")
|
| 206 |
+
|
| 207 |
+
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
|
| 208 |
del self.tokenizer
|
| 209 |
+
self.tokenizer = None
|
| 210 |
+
logger.info("LLM tokenizer deleted.")
|
| 211 |
+
|
| 212 |
+
if hasattr(self, 'text_pipeline') and self.text_pipeline is not None:
|
| 213 |
+
del self.text_pipeline
|
| 214 |
+
self.text_pipeline = None
|
| 215 |
+
logger.info("LLM pipeline deleted.")
|
| 216 |
+
|
| 217 |
+
if hasattr(self, 'whisper_model') and self.whisper_model is not None:
|
| 218 |
del self.whisper_model
|
| 219 |
+
self.whisper_model = None
|
| 220 |
+
logger.info("Whisper model deleted.")
|
| 221 |
+
|
| 222 |
+
# Explicitly clear CUDA cache and collect garbage
|
|
|
|
| 223 |
if torch.cuda.is_available():
|
| 224 |
torch.cuda.empty_cache()
|
| 225 |
+
# torch.cuda.synchronize() # May not be needed and can slow down
|
| 226 |
+
logger.info("CUDA cache cleared.")
|
| 227 |
+
|
| 228 |
gc.collect()
|
| 229 |
+
logger.info("Garbage collected. Models reset successfully.")
|
| 230 |
+
self._initialized = False # Mark as uninitialized so they reload on next use
|
| 231 |
+
|
| 232 |
except Exception as e:
|
| 233 |
logger.error(f"Error resetting models: {str(e)}")
|
| 234 |
+
logger.error(traceback.format_exc())
|
| 235 |
|
| 236 |
+
# Create global model manager instance
|
| 237 |
model_manager = ModelManager()
|
| 238 |
|
| 239 |
+
@lru_cache(maxsize=16) # Reduced cache size slightly
|
| 240 |
def download_social_media_video(url):
|
| 241 |
+
"""Download audio from a social media video URL."""
|
| 242 |
+
temp_dir = tempfile.mkdtemp()
|
| 243 |
+
output_template = os.path.join(temp_dir, '%(id)s.%(ext)s')
|
| 244 |
+
|
| 245 |
ydl_opts = {
|
| 246 |
'format': 'bestaudio/best',
|
| 247 |
'postprocessors': [{
|
| 248 |
'key': 'FFmpegExtractAudio',
|
| 249 |
'preferredcodec': 'mp3',
|
| 250 |
+
'preferredquality': '192', # Standard quality
|
| 251 |
}],
|
| 252 |
+
'outtmpl': output_template,
|
| 253 |
+
'quiet': True,
|
| 254 |
+
'no_warnings': True,
|
| 255 |
+
'nocheckcertificate': True, # Sometimes needed for tricky sites
|
| 256 |
+
'retries': 3, # Add retries
|
| 257 |
+
'socket_timeout': 15, # Timeout
|
| 258 |
}
|
| 259 |
try:
|
| 260 |
+
logger.info(f"Attempting to download audio from: {url}")
|
| 261 |
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
| 262 |
info_dict = ydl.extract_info(url, download=True)
|
| 263 |
+
# Construct the expected final filename after postprocessing
|
| 264 |
+
audio_file = os.path.join(temp_dir, f"{info_dict['id']}.mp3")
|
| 265 |
+
if not os.path.exists(audio_file):
|
| 266 |
+
# Fallback if filename doesn't match exactly (e.g., webm -> mp3)
|
| 267 |
+
found_files = [f for f in os.listdir(temp_dir) if f.endswith('.mp3')]
|
| 268 |
+
if found_files:
|
| 269 |
+
audio_file = os.path.join(temp_dir, found_files[0])
|
| 270 |
+
else:
|
| 271 |
+
raise FileNotFoundError(f"Could not find downloaded MP3 in {temp_dir}")
|
| 272 |
+
|
| 273 |
+
logger.info(f"Audio downloaded successfully: {audio_file}")
|
| 274 |
+
# Read the file content to return, as the temp dir might be cleaned up
|
| 275 |
+
with open(audio_file, 'rb') as f:
|
| 276 |
+
audio_content = f.read()
|
| 277 |
+
|
| 278 |
+
# Clean up the temporary directory and file
|
| 279 |
+
try:
|
| 280 |
+
os.remove(audio_file)
|
| 281 |
+
os.rmdir(temp_dir)
|
| 282 |
+
except OSError as e:
|
| 283 |
+
logger.warning(f"Could not completely clean up temp download files: {e}")
|
| 284 |
+
|
| 285 |
+
# Save the content to a new temporary file that Gradio can handle
|
| 286 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_output_file:
|
| 287 |
+
temp_output_file.write(audio_content)
|
| 288 |
+
final_path = temp_output_file.name
|
| 289 |
+
logger.info(f"Audio saved to temporary file: {final_path}")
|
| 290 |
+
return final_path
|
| 291 |
+
|
| 292 |
+
except yt_dlp.utils.DownloadError as e:
|
| 293 |
+
logger.error(f"yt-dlp download error for {url}: {str(e)}")
|
| 294 |
+
# Clean up temp dir on error
|
| 295 |
+
try:
|
| 296 |
+
if os.path.exists(temp_dir):
|
| 297 |
+
import shutil
|
| 298 |
+
shutil.rmtree(temp_dir)
|
| 299 |
+
except Exception as cleanup_e:
|
| 300 |
+
logger.warning(f"Error during cleanup after download failure: {cleanup_e}")
|
| 301 |
+
return None # Return None to indicate failure
|
| 302 |
except Exception as e:
|
| 303 |
+
logger.error(f"Unexpected error downloading video from {url}: {str(e)}")
|
| 304 |
+
logger.error(traceback.format_exc())
|
| 305 |
+
# Clean up temp dir on error
|
| 306 |
+
try:
|
| 307 |
+
if os.path.exists(temp_dir):
|
| 308 |
+
import shutil
|
| 309 |
+
shutil.rmtree(temp_dir)
|
| 310 |
+
except Exception as cleanup_e:
|
| 311 |
+
logger.warning(f"Error during cleanup after download failure: {cleanup_e}")
|
| 312 |
+
return None # Return None
|
| 313 |
|
| 314 |
+
def convert_video_to_audio(video_file_path):
|
| 315 |
"""Convert a video file to audio using ffmpeg directly."""
|
| 316 |
try:
|
| 317 |
+
# Create a temporary file path for the output MP3
|
| 318 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
|
| 319 |
+
output_file_path = temp_file.name
|
| 320 |
+
|
| 321 |
+
logger.info(f"Converting video '{video_file_path}' to audio '{output_file_path}'")
|
| 322 |
+
|
| 323 |
+
# Use ffmpeg directly via subprocess
|
| 324 |
command = [
|
| 325 |
+
"ffmpeg",
|
| 326 |
+
"-i", video_file_path,
|
| 327 |
+
"-vn", # No video
|
| 328 |
+
"-acodec", "libmp3lame", # Specify MP3 codec
|
| 329 |
+
"-ab", "192k", # Audio bitrate
|
| 330 |
+
"-ar", "44100", # Audio sample rate
|
| 331 |
+
"-ac", "2", # Stereo audio
|
| 332 |
+
output_file_path,
|
| 333 |
+
"-y", # Overwrite output file if it exists
|
| 334 |
+
"-loglevel", "error" # Suppress verbose ffmpeg output
|
| 335 |
]
|
| 336 |
+
|
| 337 |
+
process = subprocess.run(command, check=True, capture_output=True, text=True)
|
| 338 |
+
logger.info(f"ffmpeg conversion successful for {video_file_path}.")
|
| 339 |
+
logger.debug(f"ffmpeg stdout: {process.stdout}")
|
| 340 |
+
logger.debug(f"ffmpeg stderr: {process.stderr}")
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# Verify output file exists and has size
|
| 344 |
+
if not os.path.exists(output_file_path) or os.path.getsize(output_file_path) == 0:
|
| 345 |
+
raise RuntimeError(f"ffmpeg conversion failed: Output file '{output_file_path}' not created or is empty.")
|
| 346 |
+
|
| 347 |
+
logger.info(f"Video converted to audio: {output_file_path}")
|
| 348 |
+
return output_file_path
|
| 349 |
+
except subprocess.CalledProcessError as e:
|
| 350 |
+
logger.error(f"ffmpeg command failed with exit code {e.returncode}")
|
| 351 |
+
logger.error(f"ffmpeg stderr: {e.stderr}")
|
| 352 |
+
logger.error(f"ffmpeg stdout: {e.stdout}")
|
| 353 |
+
# Clean up potentially empty output file
|
| 354 |
+
if os.path.exists(output_file_path):
|
| 355 |
+
os.remove(output_file_path)
|
| 356 |
+
raise RuntimeError(f"ffmpeg conversion failed: {e.stderr}") from e
|
| 357 |
except Exception as e:
|
| 358 |
+
logger.error(f"Error converting video '{video_file_path}': {str(e)}")
|
| 359 |
+
logger.error(traceback.format_exc())
|
| 360 |
+
# Clean up potentially created output file
|
| 361 |
+
if 'output_file_path' in locals() and os.path.exists(output_file_path):
|
| 362 |
+
os.remove(output_file_path)
|
| 363 |
+
raise # Re-raise the exception
|
| 364 |
|
| 365 |
+
def preprocess_audio(input_audio_path):
|
| 366 |
+
"""Preprocess the audio file (e.g., normalize volume)."""
|
| 367 |
try:
|
| 368 |
+
logger.info(f"Preprocessing audio file: {input_audio_path}")
|
| 369 |
+
audio = AudioSegment.from_file(input_audio_path)
|
| 370 |
+
|
| 371 |
+
# Apply normalization (optional, adjust target dBFS as needed)
|
| 372 |
+
# Target loudness: -20 dBFS. Adjust gain based on current loudness.
|
| 373 |
+
# change_in_dBFS = -20.0 - audio.dBFS
|
| 374 |
+
# audio = audio.apply_gain(change_in_dBFS)
|
| 375 |
+
|
| 376 |
+
# Export to a new temporary file
|
| 377 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
|
| 378 |
+
output_path = temp_file.name
|
| 379 |
+
audio.export(output_path, format="mp3")
|
| 380 |
+
|
| 381 |
+
logger.info(f"Audio preprocessed and saved to: {output_path}")
|
| 382 |
+
return output_path
|
| 383 |
except Exception as e:
|
| 384 |
+
logger.error(f"Error preprocessing audio '{input_audio_path}': {str(e)}")
|
| 385 |
+
logger.error(traceback.format_exc())
|
| 386 |
+
# Return original path if preprocessing fails? Or raise error?
|
| 387 |
+
# Let's raise the error to signal failure clearly.
|
| 388 |
raise
|
| 389 |
|
| 390 |
+
@spaces.GPU(duration=300) # Allow more time for transcription
|
| 391 |
+
def transcribe_audio_or_video(file_input):
|
| 392 |
+
"""Transcribe an audio or video file (local path or Gradio File object)."""
|
| 393 |
+
audio_file_to_transcribe = None
|
| 394 |
+
original_input_path = None
|
| 395 |
+
temp_files_to_clean = []
|
| 396 |
+
|
| 397 |
try:
|
| 398 |
model_manager.check_whisper_initialized()
|
| 399 |
+
|
| 400 |
+
if file_input is None:
|
| 401 |
+
logger.info("No file input provided for transcription.")
|
| 402 |
+
return "" # Return empty string for None input
|
| 403 |
+
|
| 404 |
+
# Determine input type and get file path
|
| 405 |
+
if isinstance(file_input, str): # Input is a path
|
| 406 |
+
original_input_path = file_input
|
| 407 |
+
logger.info(f"Processing path input: {original_input_path}")
|
| 408 |
+
if not os.path.exists(original_input_path):
|
| 409 |
+
logger.error(f"Input file path does not exist: {original_input_path}")
|
| 410 |
+
raise FileNotFoundError(f"Input file not found: {original_input_path}")
|
| 411 |
+
input_path = original_input_path
|
| 412 |
+
elif hasattr(file_input, 'name'): # Input is a Gradio File object
|
| 413 |
+
original_input_path = file_input.name
|
| 414 |
+
logger.info(f"Processing Gradio file input: {original_input_path}")
|
| 415 |
+
input_path = original_input_path # Gradio usually provides a temp path
|
| 416 |
else:
|
| 417 |
+
logger.error(f"Unsupported input type for transcription: {type(file_input)}")
|
| 418 |
+
raise TypeError("Invalid input type for transcription. Expected file path or Gradio File object.")
|
| 419 |
+
|
| 420 |
+
file_extension = os.path.splitext(input_path)[1].lower()
|
| 421 |
+
|
| 422 |
+
# Check if it's a video file that needs conversion
|
| 423 |
+
if file_extension in ['.mp4', '.avi', '.mov', '.mkv', '.webm']:
|
| 424 |
+
logger.info(f"Detected video file ({file_extension}), converting to audio...")
|
| 425 |
+
converted_audio_path = convert_video_to_audio(input_path)
|
| 426 |
+
temp_files_to_clean.append(converted_audio_path)
|
| 427 |
+
audio_file_to_process = converted_audio_path
|
| 428 |
+
elif file_extension in ['.mp3', '.wav', '.ogg', '.flac', '.m4a']:
|
| 429 |
+
logger.info(f"Detected audio file ({file_extension}).")
|
| 430 |
+
audio_file_to_process = input_path
|
| 431 |
+
else:
|
| 432 |
+
logger.error(f"Unsupported file extension for transcription: {file_extension}")
|
| 433 |
+
raise ValueError(f"Unsupported file type: {file_extension}")
|
| 434 |
+
|
| 435 |
+
# Preprocess the audio (optional, could be skipped if causing issues)
|
| 436 |
try:
|
| 437 |
+
preprocessed_audio_path = preprocess_audio(audio_file_to_process)
|
| 438 |
+
# If preprocessing creates a new file different from the input, add it to cleanup
|
| 439 |
+
if preprocessed_audio_path != audio_file_to_process:
|
| 440 |
+
temp_files_to_clean.append(preprocessed_audio_path)
|
| 441 |
+
audio_file_to_transcribe = preprocessed_audio_path
|
| 442 |
+
except Exception as preprocess_err:
|
| 443 |
+
logger.warning(f"Audio preprocessing failed: {preprocess_err}. Using original/converted audio.")
|
| 444 |
+
audio_file_to_transcribe = audio_file_to_process # Fallback
|
| 445 |
+
|
| 446 |
+
logger.info(f"Transcribing audio file: {audio_file_to_transcribe}")
|
| 447 |
+
if not os.path.exists(audio_file_to_transcribe):
|
| 448 |
+
raise FileNotFoundError(f"Audio file to transcribe not found: {audio_file_to_transcribe}")
|
| 449 |
+
|
| 450 |
+
# Perform transcription
|
| 451 |
+
with torch.inference_mode(): # Ensure inference mode for efficiency
|
| 452 |
+
# Use fp16 if available on CUDA
|
| 453 |
+
use_fp16 = torch.cuda.is_available()
|
| 454 |
+
result = model_manager.whisper_model.transcribe(
|
| 455 |
+
audio_file_to_transcribe,
|
| 456 |
+
fp16=use_fp16
|
| 457 |
+
)
|
| 458 |
+
if not result:
|
| 459 |
+
raise RuntimeError("Transcription failed to produce results")
|
| 460 |
+
|
| 461 |
+
transcription = result.get("text", "Error: Transcription result empty")
|
| 462 |
+
# Limit transcription length shown in logs
|
| 463 |
+
log_transcription = (transcription[:100] + '...') if len(transcription) > 100 else transcription
|
| 464 |
+
logger.info(f"Transcription completed: {log_transcription}")
|
| 465 |
+
|
| 466 |
return transcription
|
| 467 |
+
|
| 468 |
+
except FileNotFoundError as e:
|
| 469 |
+
logger.error(f"File not found error during transcription: {e}")
|
| 470 |
+
return f"Error: Input file not found ({e})"
|
| 471 |
+
except ValueError as e:
|
| 472 |
+
logger.error(f"Value error during transcription: {e}")
|
| 473 |
+
return f"Error: Unsupported file type ({e})"
|
| 474 |
+
except TypeError as e:
|
| 475 |
+
logger.error(f"Type error during transcription setup: {e}")
|
| 476 |
+
return f"Error: Invalid input provided ({e})"
|
| 477 |
+
except RuntimeError as e:
|
| 478 |
+
logger.error(f"Runtime error during transcription: {e}")
|
| 479 |
+
logger.error(traceback.format_exc())
|
| 480 |
+
return f"Error during processing: {e}"
|
| 481 |
except Exception as e:
|
| 482 |
+
logger.error(f"Unexpected error during transcription: {str(e)}")
|
| 483 |
+
logger.error(traceback.format_exc())
|
| 484 |
+
return f"Error processing the file: An unexpected error occurred."
|
| 485 |
+
|
| 486 |
+
finally:
|
| 487 |
+
# Clean up all temporary files created during the process
|
| 488 |
+
for temp_file in temp_files_to_clean:
|
| 489 |
+
try:
|
| 490 |
+
if os.path.exists(temp_file):
|
| 491 |
+
os.remove(temp_file)
|
| 492 |
+
logger.info(f"Cleaned up temporary file: {temp_file}")
|
| 493 |
+
except Exception as e:
|
| 494 |
+
logger.warning(f"Could not remove temporary file {temp_file}: {str(e)}")
|
| 495 |
+
# Optionally reset models if idle (might be too aggressive here)
|
| 496 |
+
# model_manager.reset_models()
|
| 497 |
|
| 498 |
+
@lru_cache(maxsize=16)
|
| 499 |
def read_document(document_path):
|
| 500 |
+
"""Read the content of a document (PDF, DOCX, XLSX, CSV)."""
|
| 501 |
try:
|
| 502 |
+
logger.info(f"Reading document: {document_path}")
|
| 503 |
+
if not os.path.exists(document_path):
|
| 504 |
+
raise FileNotFoundError(f"Document not found: {document_path}")
|
| 505 |
+
|
| 506 |
+
file_extension = os.path.splitext(document_path)[1].lower()
|
| 507 |
+
|
| 508 |
+
if file_extension == ".pdf":
|
| 509 |
doc = fitz.open(document_path)
|
| 510 |
+
text = "\n".join([page.get_text() for page in doc])
|
| 511 |
+
doc.close()
|
| 512 |
+
return text
|
| 513 |
+
elif file_extension == ".docx":
|
| 514 |
doc = docx.Document(document_path)
|
| 515 |
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
| 516 |
+
elif file_extension in (".xlsx", ".xls"):
|
| 517 |
+
# Read all sheets and combine
|
| 518 |
+
xls = pd.ExcelFile(document_path)
|
| 519 |
+
text = ""
|
| 520 |
+
for sheet_name in xls.sheet_names:
|
| 521 |
+
df = pd.read_excel(xls, sheet_name=sheet_name)
|
| 522 |
+
text += f"--- Sheet: {sheet_name} ---\n{df.to_string()}\n\n"
|
| 523 |
+
return text.strip()
|
| 524 |
+
elif file_extension == ".csv":
|
| 525 |
+
# Try detecting separator
|
| 526 |
+
try:
|
| 527 |
+
df = pd.read_csv(document_path)
|
| 528 |
+
except pd.errors.ParserError:
|
| 529 |
+
logger.warning(f"Could not parse CSV {document_path} with default comma separator, trying semicolon.")
|
| 530 |
+
df = pd.read_csv(document_path, sep=';')
|
| 531 |
+
return df.to_string()
|
| 532 |
else:
|
| 533 |
+
logger.warning(f"Unsupported document type: {file_extension}")
|
| 534 |
return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
|
| 535 |
+
except FileNotFoundError as e:
|
| 536 |
+
logger.error(f"Error reading document: {e}")
|
| 537 |
+
return f"Error: Document file not found at {document_path}"
|
| 538 |
except Exception as e:
|
| 539 |
+
logger.error(f"Error reading document {document_path}: {str(e)}")
|
| 540 |
+
logger.error(traceback.format_exc())
|
| 541 |
return f"Error reading document: {str(e)}"
|
| 542 |
|
| 543 |
+
@lru_cache(maxsize=16)
|
| 544 |
def read_url(url):
|
| 545 |
+
"""Read the main textual content of a URL."""
|
| 546 |
+
if not url or not url.strip().startswith('http'):
|
| 547 |
+
logger.info(f"Invalid or empty URL provided: '{url}'")
|
| 548 |
+
return "" # Return empty for invalid or empty URLs
|
| 549 |
+
|
| 550 |
try:
|
| 551 |
+
logger.info(f"Reading URL: {url}")
|
| 552 |
headers = {
|
| 553 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
| 554 |
}
|
| 555 |
+
# Increased timeout
|
| 556 |
+
response = requests.get(url, headers=headers, timeout=20, allow_redirects=True)
|
| 557 |
+
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
|
| 558 |
+
|
| 559 |
+
# Check content type - proceed only if likely HTML/text
|
| 560 |
+
content_type = response.headers.get('content-type', '').lower()
|
| 561 |
+
if not ('html' in content_type or 'text' in content_type):
|
| 562 |
+
logger.warning(f"URL {url} has non-text content type: {content_type}. Skipping.")
|
| 563 |
+
return f"Error: URL content type ({content_type}) is not text/html."
|
| 564 |
+
|
| 565 |
soup = BeautifulSoup(response.content, 'html.parser')
|
| 566 |
+
|
| 567 |
+
# Remove non-content elements like scripts, styles, nav, footers etc.
|
| 568 |
+
for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav", "aside", "form", "button"]):
|
| 569 |
element.extract()
|
| 570 |
+
|
| 571 |
+
# Attempt to find main content area (common tags/attributes)
|
| 572 |
+
main_content = (
|
| 573 |
+
soup.find("main") or
|
| 574 |
+
soup.find("article") or
|
| 575 |
+
soup.find("div", class_=["content", "main", "post-content", "entry-content", "article-body"]) or
|
| 576 |
+
soup.find("div", id=["content", "main", "article"])
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
if main_content:
|
| 580 |
text = main_content.get_text(separator='\n', strip=True)
|
| 581 |
else:
|
| 582 |
+
# Fallback to body if no specific main content found
|
| 583 |
+
body = soup.find("body")
|
| 584 |
+
if body:
|
| 585 |
+
text = body.get_text(separator='\n', strip=True)
|
| 586 |
+
else: # Very basic fallback
|
| 587 |
+
text = soup.get_text(separator='\n', strip=True)
|
| 588 |
+
|
| 589 |
+
# Clean up whitespace: replace multiple newlines/spaces with single ones
|
| 590 |
+
text = '\n'.join([line.strip() for line in text.split('\n') if line.strip()])
|
| 591 |
+
text = ' '.join(text.split()) # Consolidate spaces within lines
|
| 592 |
+
|
| 593 |
+
if not text:
|
| 594 |
+
logger.warning(f"Could not extract meaningful text from URL: {url}")
|
| 595 |
+
return "Error: Could not extract text content from URL."
|
| 596 |
+
|
| 597 |
+
# Limit content size to avoid overwhelming the LLM
|
| 598 |
+
max_chars = 15000
|
| 599 |
+
if len(text) > max_chars:
|
| 600 |
+
logger.info(f"URL content truncated to {max_chars} characters.")
|
| 601 |
+
text = text[:max_chars] + "... [content truncated]"
|
| 602 |
+
|
| 603 |
+
return text
|
| 604 |
+
except requests.exceptions.RequestException as e:
|
| 605 |
+
logger.error(f"Error fetching URL {url}: {str(e)}")
|
| 606 |
+
return f"Error reading URL: Could not fetch content ({e})"
|
| 607 |
except Exception as e:
|
| 608 |
+
logger.error(f"Error parsing URL {url}: {str(e)}")
|
| 609 |
+
logger.error(traceback.format_exc())
|
| 610 |
+
return f"Error reading URL: Could not parse content ({e})"
|
| 611 |
|
| 612 |
+
def process_social_media_url(url):
|
| 613 |
+
"""Process a social media URL, attempting to get text and transcribe video/audio."""
|
| 614 |
+
if not url or not url.strip().startswith('http'):
|
| 615 |
+
logger.info(f"Invalid or empty social media URL: '{url}'")
|
| 616 |
return None
|
| 617 |
+
|
| 618 |
+
logger.info(f"Processing social media URL: {url}")
|
| 619 |
+
text_content = None
|
| 620 |
+
video_transcription = None
|
| 621 |
+
error_occurred = False
|
| 622 |
+
|
| 623 |
+
# 1. Try extracting text content using read_url (might work for some platforms/posts)
|
| 624 |
try:
|
| 625 |
text_content = read_url(url)
|
| 626 |
+
if text_content and text_content.startswith("Error:"):
|
| 627 |
+
logger.warning(f"Failed to read text content from social URL {url}: {text_content}")
|
| 628 |
+
text_content = None # Reset if it was an error message
|
| 629 |
+
except Exception as e:
|
| 630 |
+
logger.error(f"Error reading text content from social URL {url}: {e}")
|
| 631 |
+
error_occurred = True
|
| 632 |
|
| 633 |
+
# 2. Try downloading and transcribing potential video/audio content
|
| 634 |
+
downloaded_audio_path = None
|
| 635 |
+
try:
|
| 636 |
+
downloaded_audio_path = download_social_media_video(url)
|
| 637 |
+
if downloaded_audio_path:
|
| 638 |
+
logger.info(f"Audio downloaded from {url}, proceeding to transcription.")
|
| 639 |
+
video_transcription = transcribe_audio_or_video(downloaded_audio_path)
|
| 640 |
+
if video_transcription and video_transcription.startswith("Error"):
|
| 641 |
+
logger.warning(f"Transcription failed for audio from {url}: {video_transcription}")
|
| 642 |
+
video_transcription = None # Reset if it was an error
|
| 643 |
+
else:
|
| 644 |
+
logger.info(f"No downloadable audio/video found or download failed for URL: {url}")
|
| 645 |
+
except Exception as e:
|
| 646 |
+
logger.error(f"Error processing video content from social URL {url}: {e}")
|
| 647 |
+
logger.error(traceback.format_exc())
|
| 648 |
+
error_occurred = True
|
| 649 |
+
finally:
|
| 650 |
+
# Clean up downloaded file if it exists
|
| 651 |
+
if downloaded_audio_path and os.path.exists(downloaded_audio_path):
|
| 652 |
+
try:
|
| 653 |
+
os.remove(downloaded_audio_path)
|
| 654 |
+
logger.info(f"Cleaned up downloaded audio: {downloaded_audio_path}")
|
| 655 |
+
except Exception as e:
|
| 656 |
+
logger.warning(f"Failed to cleanup downloaded audio {downloaded_audio_path}: {e}")
|
| 657 |
+
|
| 658 |
+
# Return results only if some content was found or no critical error occurred
|
| 659 |
+
if text_content or video_transcription or not error_occurred:
|
| 660 |
return {
|
| 661 |
+
"text": text_content or "", # Ensure string type
|
| 662 |
+
"video": video_transcription or "" # Ensure string type
|
| 663 |
}
|
| 664 |
+
else:
|
| 665 |
+
logger.error(f"Failed to process social media URL {url} completely.")
|
| 666 |
+
return None # Indicate failure
|
| 667 |
+
|
| 668 |
|
| 669 |
+
@spaces.GPU(duration=300) # Allow more time for generation
|
| 670 |
def generate_news(instructions, facts, size, tone, *args):
|
| 671 |
+
"""Generate a news article based on provided data using an LLM."""
|
| 672 |
+
request_start_time = time.time()
|
| 673 |
+
logger.info("Received request to generate news.")
|
| 674 |
try:
|
| 675 |
+
# Ensure size is integer
|
| 676 |
+
try:
|
| 677 |
+
size = int(size) if size else 250 # Default size if None/empty
|
| 678 |
+
except ValueError:
|
| 679 |
+
logger.warning(f"Invalid size value '{size}', defaulting to 250.")
|
| 680 |
size = 250
|
| 681 |
+
|
| 682 |
+
# Check if models are initialized, load if necessary
|
| 683 |
+
model_manager.check_llm_initialized() # LLM is essential
|
| 684 |
+
# Whisper might be needed later, check/load if audio sources exist
|
| 685 |
+
|
| 686 |
+
# --- Argument Parsing ---
|
| 687 |
+
# The order *must* match the order components are added to inputs_list in create_demo
|
| 688 |
+
# Fixed inputs: instructions, facts, size, tone (already passed directly)
|
| 689 |
+
# Dynamic inputs from *args:
|
| 690 |
+
# Expected order in *args based on create_demo:
|
| 691 |
+
# 5 Documents, 15 Audio-related, 5 URLs, 9 Social-related
|
| 692 |
+
num_docs = 5
|
| 693 |
+
num_audio_sources = 5
|
| 694 |
+
num_audio_inputs_per_source = 3
|
| 695 |
+
num_urls = 5
|
| 696 |
+
num_social_sources = 3
|
| 697 |
+
num_social_inputs_per_source = 3
|
| 698 |
+
|
| 699 |
+
total_expected_args = num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls + (num_social_sources * num_social_inputs_per_source)
|
| 700 |
+
|
| 701 |
+
args_list = list(args)
|
| 702 |
+
# Pad args_list with None if fewer arguments were received than expected
|
| 703 |
+
args_list.extend([None] * (total_expected_args - len(args_list)))
|
| 704 |
+
|
| 705 |
+
# Slice arguments based on the expected order
|
| 706 |
+
doc_files = args_list[0:num_docs]
|
| 707 |
+
audio_inputs_flat = args_list[num_docs : num_docs + (num_audio_sources * num_audio_inputs_per_source)]
|
| 708 |
+
url_inputs = args_list[num_docs + (num_audio_sources * num_audio_inputs_per_source) : num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls]
|
| 709 |
+
social_inputs_flat = args_list[num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls : total_expected_args]
|
| 710 |
+
|
| 711 |
knowledge_base = {
|
| 712 |
+
"instructions": instructions or "No specific instructions provided.",
|
| 713 |
+
"facts": facts or "No specific facts provided.",
|
| 714 |
"document_content": [],
|
| 715 |
"audio_data": [],
|
| 716 |
"url_content": [],
|
| 717 |
"social_content": []
|
| 718 |
}
|
| 719 |
+
raw_transcriptions = "" # Initialize transcription log
|
| 720 |
|
| 721 |
+
# --- Process Inputs ---
|
| 722 |
+
logger.info("Processing document inputs...")
|
| 723 |
+
for i, doc_file in enumerate(doc_files):
|
| 724 |
+
if doc_file and hasattr(doc_file, 'name'):
|
| 725 |
+
try:
|
| 726 |
+
content = read_document(doc_file.name) # doc_file.name is the temp path
|
| 727 |
+
if content and not content.startswith("Error"):
|
| 728 |
+
# Truncate long documents for the knowledge base summary
|
| 729 |
+
doc_excerpt = (content[:1000] + "... [document truncated]") if len(content) > 1000 else content
|
| 730 |
+
knowledge_base["document_content"].append(f"[Document {i+1} Source: {os.path.basename(doc_file.name)}]\n{doc_excerpt}")
|
| 731 |
+
else:
|
| 732 |
+
logger.warning(f"Skipping document {i+1} due to read error or empty content: {content}")
|
| 733 |
+
except Exception as e:
|
| 734 |
+
logger.error(f"Failed to process document {i+1} ({doc_file.name}): {e}")
|
| 735 |
+
# No cleanup needed here, Gradio handles temp file uploads
|
| 736 |
+
|
| 737 |
+
logger.info("Processing URL inputs...")
|
| 738 |
+
for i, url in enumerate(url_inputs):
|
| 739 |
+
if url and isinstance(url, str) and url.strip().startswith('http'):
|
| 740 |
+
try:
|
| 741 |
+
content = read_url(url)
|
| 742 |
+
if content and not content.startswith("Error"):
|
| 743 |
+
# Content is already truncated in read_url if needed
|
| 744 |
+
knowledge_base["url_content"].append(f"[URL {i+1} Source: {url}]\n{content}")
|
| 745 |
+
else:
|
| 746 |
+
logger.warning(f"Skipping URL {i+1} ({url}) due to read error or empty content: {content}")
|
| 747 |
+
except Exception as e:
|
| 748 |
+
logger.error(f"Failed to process URL {i+1} ({url}): {e}")
|
| 749 |
|
| 750 |
+
logger.info("Processing audio/video inputs...")
|
| 751 |
+
has_audio_source = False
|
| 752 |
+
for i in range(num_audio_sources):
|
| 753 |
+
start_idx = i * num_audio_inputs_per_source
|
| 754 |
+
audio_file = audio_inputs_flat[start_idx]
|
| 755 |
+
name = audio_inputs_flat[start_idx + 1] or f"Source {i+1}"
|
| 756 |
+
position = audio_inputs_flat[start_idx + 2] or "N/A"
|
| 757 |
+
|
| 758 |
+
if audio_file and hasattr(audio_file, 'name'):
|
| 759 |
+
# Store info for transcription later
|
| 760 |
+
knowledge_base["audio_data"].append({
|
| 761 |
+
"file_path": audio_file.name, # Use the temp path
|
| 762 |
+
"name": name,
|
| 763 |
+
"position": position,
|
| 764 |
+
"original_filename": os.path.basename(audio_file.name) # Keep original for logs
|
| 765 |
+
})
|
| 766 |
+
has_audio_source = True
|
| 767 |
+
logger.info(f"Added audio source {i+1}: {name} ({position}) - File: {knowledge_base['audio_data'][-1]['original_filename']}")
|
| 768 |
+
|
| 769 |
+
logger.info("Processing social media inputs...")
|
| 770 |
+
has_social_source = False
|
| 771 |
+
for i in range(num_social_sources):
|
| 772 |
+
start_idx = i * num_social_inputs_per_source
|
| 773 |
+
social_url = social_inputs_flat[start_idx]
|
| 774 |
+
social_name = social_inputs_flat[start_idx + 1] or f"Social Source {i+1}"
|
| 775 |
+
social_context = social_inputs_flat[start_idx + 2] or "N/A"
|
| 776 |
+
|
| 777 |
+
if social_url and isinstance(social_url, str) and social_url.strip().startswith('http'):
|
| 778 |
+
try:
|
| 779 |
+
logger.info(f"Processing social media URL {i+1}: {social_url}")
|
| 780 |
+
social_data = process_social_media_url(social_url)
|
| 781 |
+
if social_data:
|
| 782 |
+
knowledge_base["social_content"].append({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
"url": social_url,
|
| 784 |
+
"name": social_name,
|
| 785 |
+
"context": social_context,
|
| 786 |
+
"text": social_data.get("text", ""),
|
| 787 |
+
"video_transcription": social_data.get("video", "") # Store potential transcription
|
| 788 |
})
|
| 789 |
+
has_social_source = True
|
| 790 |
+
logger.info(f"Added social source {i+1}: {social_name} ({social_context}) from {social_url}")
|
| 791 |
+
else:
|
| 792 |
+
logger.warning(f"Could not retrieve any content for social URL {i+1}: {social_url}")
|
| 793 |
+
except Exception as e:
|
| 794 |
+
logger.error(f"Failed to process social URL {i+1} ({social_url}): {e}")
|
| 795 |
|
|
|
|
|
|
|
| 796 |
|

        # --- Transcribe Audio/Video ---
        # Only initialize Whisper if transcription is actually needed
        transcriptions_for_prompt = ""
        if has_audio_source or any(sc.get("video_transcription") == "[NEEDS_TRANSCRIPTION]" for sc in knowledge_base["social_content"]):
            logger.info("Audio sources detected, ensuring Whisper model is ready...")
            try:
                model_manager.check_whisper_initialized()
            except Exception as whisper_init_err:
                logger.error(f"FATAL: Whisper model initialization failed: {whisper_init_err}. Cannot transcribe.")
                # Record the error in the raw transcriptions log and continue without transcriptions
                raw_transcriptions += f"[ERROR] Whisper model failed to load. Audio sources could not be transcribed: {whisper_init_err}\n\n"
                # Alternatively, return an error message immediately:
                # return f"Error: Could not initialize transcription model. {whisper_init_err}", raw_transcriptions

        if model_manager.whisper_model:  # Proceed only if Whisper loaded successfully
            logger.info("Transcribing collected audio sources...")
            for idx, data in enumerate(knowledge_base["audio_data"]):
                try:
                    logger.info(f"Transcribing audio source {idx+1}: {data['original_filename']} ({data['name']}, {data['position']})")
                    transcription = transcribe_audio_or_video(data["file_path"])
                    if transcription and not transcription.startswith("Error"):
                        quote = f'"{transcription}" - {data["name"]}, {data["position"]}'
                        transcriptions_for_prompt += f"{quote}\n\n"
                        raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n"{transcription}"\n\n'
                    else:
                        logger.warning(f"Transcription failed or returned error for audio source {idx+1}: {transcription}")
                        raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n[Error during transcription: {transcription}]\n\n'
                except Exception as e:
                    logger.error(f"Error during transcription for audio source {idx+1} ({data['original_filename']}): {e}")
                    logger.error(traceback.format_exc())
                    raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n[Error during transcription: {e}]\n\n'
                # Gradio handles cleanup of the uploaded temp file (audio_file.name)

        logger.info("Adding social media content to prompt data...")
        for idx, data in enumerate(knowledge_base["social_content"]):
            source_id = f'[Social Media {idx+1}: {data["url"]} ({data["name"]}, {data["context"]})]'
            has_content = False
            if data["text"] and not data["text"].startswith("Error"):
                # Truncate long text for the prompt; the full text is still logged below
                text_excerpt = (data["text"][:500] + "...[text truncated]") if len(data["text"]) > 500 else data["text"]
                social_text_prompt = f'{source_id} - Text Content:\n"{text_excerpt}"\n\n'
                transcriptions_for_prompt += social_text_prompt  # Treat text content as another quotable source
                raw_transcriptions += f"{source_id}\nText Content:\n{data['text']}\n\n"  # Log full text
                has_content = True
            if data["video_transcription"] and not data["video_transcription"].startswith("Error"):
                social_video_prompt = f'{source_id} - Video Transcription:\n"{data["video_transcription"]}"\n\n'
                transcriptions_for_prompt += social_video_prompt
                raw_transcriptions += f"{source_id}\nVideo Transcription:\n{data['video_transcription']}\n\n"
                has_content = True
            if not has_content:
                raw_transcriptions += f"{source_id}\n[No usable text or video transcription found]\n\n"

        # --- Prepare Final Prompt ---
        # Combine document, URL, and transcription summaries
        document_summary = "\n\n".join(knowledge_base["document_content"]) if knowledge_base["document_content"] else "No document content provided."
        url_summary = "\n\n".join(knowledge_base["url_content"]) if knowledge_base["url_content"] else "No URL content provided."
        transcription_summary = transcriptions_for_prompt if transcriptions_for_prompt else "No usable transcriptions available."

        # Construct the prompt for the LLM
        prompt = f"""<s>[INST] You are a professional news writer. Your task is to synthesize information from various sources into a coherent news article.

Primary Instructions: {knowledge_base["instructions"]}
Key Facts to Include: {knowledge_base["facts"]}

Supporting Information:

Document Content Summary:
{document_summary}

Web Content Summary (from URLs):
{url_summary}

Transcribed Quotes/Content (Use these directly or indirectly):
{transcription_summary}

Article Requirements:
- Title: Create a concise and informative title for the article.
- Hook: Write a compelling hook sentence of about 15 words that complements the title.
- Body: Write the main news article body, aiming for approximately {size} words.
- Tone: Adopt a {tone} tone throughout the article.
- 5 Ws: Ensure the first paragraph addresses the core questions (Who, What, When, Where, Why).
- Quotes: Incorporate relevant information from the 'Transcribed Quotes/Content' section. Use quotes where appropriate, but synthesize information rather than simply listing quotes. Use quotation marks (" ") for direct quotes, attributed correctly (e.g., based on the name/position provided).
- Style: Adhere to a professional journalistic style. Be objective and factual.
- Accuracy: Do NOT invent information. Stick strictly to the provided facts, instructions, and source materials. If information is contradictory or missing, say so or omit the detail.
- Structure: Organize the article logically with clear paragraphs.

Begin the article now. [/INST]
Article Draft:
"""

        # Log the prompt length (useful for debugging context limits)
        logger.info(f"Generated prompt length: {len(prompt.split())} words / {len(prompt)} characters.")
        # Avoid logging the full prompt; it can be long and may contain sensitive information
        # logger.debug(f"Generated Prompt:\n{prompt}")
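
        # NOTE (hedged sketch): TinyLlama-1.1B-Chat-v1.0 was fine-tuned on a Zephyr-style
        # chat format, so the hand-written "<s>[INST] ... [/INST]" markers above may not
        # match its training template exactly. If the loaded tokenizer ships a chat
        # template, letting it build the prompt is more robust, e.g.:
        #
        #   messages = [{"role": "user", "content": user_request}]  # user_request: hypothetical variable holding the instructions + sources text
        #   prompt = model_manager.tokenizer.apply_chat_template(
        #       messages, tokenize=False, add_generation_prompt=True
        #   )
        #
        # This alternative is untested here; the hand-built prompt remains the active path.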

        # --- Generate News Article ---
        logger.info("Generating news article with LLM...")
        generation_start_time = time.time()

        # Estimate max_new_tokens from the requested size, with a buffer
        # for the title, hook, and potential verbosity
        estimated_tokens_per_word = 1.5
        max_new_tokens = int(size * estimated_tokens_per_word + 150)
        # Ensure max_new_tokens doesn't exceed the model's context window
        model_max_length = 2048  # Typical for TinyLlama; check the specific model card
        # This token count is approximate; precise tokenization would be more accurate:
        # prompt_tokens = len(model_manager.tokenizer.encode(prompt))  # More accurate but slower
        prompt_tokens_estimate = len(prompt) // 3  # Rough character-based estimate
        max_new_tokens = min(max_new_tokens, model_max_length - prompt_tokens_estimate - 50)  # Leave a buffer
        max_new_tokens = max(max_new_tokens, 100)  # Guarantee a minimum generation length

        logger.info(f"Requesting max_new_tokens: {max_new_tokens}")

        try:
            # Generate using the pipeline
            outputs = model_manager.text_pipeline(
                prompt,
                max_new_tokens=max_new_tokens,  # max_new_tokens rather than max_length
                do_sample=True,
                temperature=0.7,  # Balanced temperature: varied but still factual
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.15,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                num_return_sequences=1
            )
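
            # Hedged sketch: for reproducible drafts during debugging, transformers
            # exposes a global seed helper that makes sampled generations repeatable:
            #
            #   from transformers import set_seed
            #   set_seed(42)  # call before text_pipeline(...) for a repeatable draft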

            # Extract the generated text
            generated_text = outputs[0]['generated_text']

            # Clean up the result by removing the prompt:
            # find the end-of-prompt marker [/INST] and take the text after it
            inst_marker = "[/INST]"
            marker_pos = generated_text.find(inst_marker)
            if marker_pos != -1:
                news_article = generated_text[marker_pos + len(inst_marker):].strip()
                # Also strip a leading "Article Draft:" if the model echoed it
                if news_article.startswith("Article Draft:"):
                    news_article = news_article[len("Article Draft:"):].strip()
            else:
                # Fallback: try removing the input prompt string itself (less reliable)
                if prompt in generated_text:
                    news_article = generated_text.replace(prompt, "", 1).strip()
                else:
                    # If the prompt is not found verbatim, assume the output is only the
                    # generation (some pipelines strip the prompt internally)
                    news_article = generated_text
                    logger.warning("Prompt marker '[/INST]' not found in LLM output. Returning full output.")
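
            # Hedged sketch: the transformers text-generation pipeline also accepts
            # return_full_text=False, which strips the prompt from the output and would
            # make the "[/INST]" marker search above unnecessary:
            #
            #   outputs = model_manager.text_pipeline(
            #       prompt,
            #       max_new_tokens=max_new_tokens,
            #       do_sample=True,
            #       temperature=0.7,
            #       return_full_text=False,
            #   )
            #   news_article = outputs[0]['generated_text'].strip()
            #
            # Kept as a comment; the marker-based cleanup above is the tested path.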

            generation_time = time.time() - generation_start_time
            logger.info(f"News generation completed in {generation_time:.2f} seconds. Output length: {len(news_article)} characters.")

        except torch.cuda.OutOfMemoryError as oom_error:
            logger.error(f"CUDA Out of Memory error during LLM generation: {oom_error}")
            logger.error(traceback.format_exc())
            model_manager.reset_models(force=True)  # Attempt to recover
            raise RuntimeError("Generation failed due to insufficient GPU memory. Please try reducing article size or complexity.") from oom_error
        except Exception as gen_error:
            logger.error(f"Error during text generation pipeline: {str(gen_error)}")
            logger.error(traceback.format_exc())
            raise RuntimeError(f"LLM generation failed: {gen_error}") from gen_error
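
        # Hedged sketch: if reset_models() does not already release cached GPU memory,
        # doing so explicitly can help recovery after an OOM (gc and torch are imported
        # at the top of this file):
        #
        #   gc.collect()
        #   if torch.cuda.is_available():
        #       torch.cuda.empty_cache()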

        total_time = time.time() - request_start_time
        logger.info(f"Total request processing time: {total_time:.2f} seconds.")

        # Return the generated article and the log of raw transcriptions
        return news_article, raw_transcriptions.strip()

    except Exception as e:
        total_time = time.time() - request_start_time
        logger.error(f"Error in generate_news function after {total_time:.2f} seconds: {str(e)}")
        logger.error(traceback.format_exc())
        # Attempt to reset models to recover state if possible
        try:
            model_manager.reset_models(force=True)
        except Exception as reset_error:
            logger.error(f"Failed to reset models after error: {str(reset_error)}")
        # Return error messages to the UI
        error_message = f"Error generating the news article: {str(e)}"
        transcription_log = raw_transcriptions.strip() + f"\n\n[ERROR] News generation failed: {str(e)}"
        return error_message, transcription_log


def create_demo():
    """Creates the Gradio interface"""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📰 NewsIA - AI News Generator")
        gr.Markdown("Create professional news articles from multiple information sources.")

        # Store all input components for easy access/reset
        all_inputs = []

        with gr.Row():
            with gr.Column(scale=2):
                instructions = gr.Textbox(
                    label="Instructions for the News Article",
                    placeholder="Enter specific instructions for generating your news article (e.g., focus on the economic impact)",
                    lines=2,
                    value=""
                )
                all_inputs.append(instructions)

                facts = gr.Textbox(
                    label="Main Facts",
                    placeholder="Describe the most important facts the news should include (e.g., event name, date, location, key people involved)",
                    lines=4,
                    value=""
                )
                all_inputs.append(facts)

                with gr.Row():
                    size_slider = gr.Slider(
                        label="Approximate Length (words)",
                        minimum=100,
                        maximum=700,  # Increased maximum size
                        value=250,
                        step=50
                    )
                    all_inputs.append(size_slider)

                    tone_dropdown = gr.Dropdown(
                        label="Tone of the News Article",
                        choices=["neutral", "serious", "formal", "urgent", "investigative", "human-interest", "lighthearted"],
                        value="neutral"
                    )
                    all_inputs.append(tone_dropdown)

            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("📄 Documents"):
                        gr.Markdown("Upload relevant documents (PDF, DOCX, XLSX, CSV). Max 5.")
                        doc_inputs = []
                        for i in range(1, 6):
                            doc_file = gr.File(
                                label=f"Document {i}",
                                file_types=[".pdf", ".docx", ".xlsx", ".csv"],  # Explicit extensions for clarity
                                file_count="single"  # One file per component
                            )
                            doc_inputs.append(doc_file)
                        all_inputs.extend(doc_inputs)

                    with gr.TabItem("🎤 Audio/Video"):
                        gr.Markdown("Upload audio or video files for transcription (MP3, WAV, MP4, MOV, etc.). Max 5 sources.")
                        audio_video_inputs = []
                        for i in range(1, 6):
                            with gr.Group():
                                gr.Markdown(f"**Source {i}**")
                                audio_file = gr.File(
                                    label=f"Audio/Video File {i}",
                                    file_types=["audio", "video"]
                                )
                                with gr.Row():
                                    speaker_name = gr.Textbox(
                                        label="Speaker Name",
                                        placeholder="Name of the interviewee or speaker",
                                        value=""
                                    )
                                    speaker_role = gr.Textbox(
                                        label="Role/Position",
                                        placeholder="Speaker's title or role",
                                        value=""
                                    )
                                audio_video_inputs.append(audio_file)
                                audio_video_inputs.append(speaker_name)
                                audio_video_inputs.append(speaker_role)
                        all_inputs.extend(audio_video_inputs)
with gr.TabItem("π URLs"):
|
| 1067 |
+
gr.Markdown("Add URLs to relevant web pages or articles. Max 5.")
|
| 1068 |
+
url_inputs = []
|
| 1069 |
+
for i in range(1, 6):
|
| 1070 |
+
url_textbox = gr.Textbox(
|
| 1071 |
label=f"URL {i}",
|
| 1072 |
+
placeholder="https://example.com/article",
|
| 1073 |
+
value=""
|
| 1074 |
)
|
| 1075 |
+
url_inputs.append(url_textbox)
|
| 1076 |
+
all_inputs.extend(url_inputs)
|
| 1077 |
|

                    with gr.TabItem("📱 Social Media"):
                        gr.Markdown("Add URLs to social media posts (e.g., Twitter, YouTube, TikTok). Max 3.")
                        social_inputs = []
                        for i in range(1, 4):
                            with gr.Group():
                                gr.Markdown(f"**Social Media Source {i}**")
                                social_url_textbox = gr.Textbox(
                                    label="Post URL",
                                    placeholder="https://twitter.com/user/status/...",
                                    value=""
                                )
                                with gr.Row():
                                    social_name_textbox = gr.Textbox(
                                        label="Account Name/User",
                                        placeholder="Name or handle (e.g., @username)",
                                        value=""
                                    )
                                    social_context_textbox = gr.Textbox(
                                        label="Context",
                                        placeholder="Brief context (e.g., statement on event X)",
                                        value=""
                                    )
                                social_inputs.append(social_url_textbox)
                                social_inputs.append(social_name_textbox)
                                social_inputs.append(social_context_textbox)
                        all_inputs.extend(social_inputs)

        with gr.Row():
            generate_button = gr.Button("✨ Generate News Article", variant="primary")
            clear_button = gr.Button("🔄 Clear All Inputs")

        with gr.Tabs():
            with gr.TabItem("📝 Generated News Article"):
                news_output = gr.Textbox(
                    label="Draft News Article",
                    lines=20,
                    show_copy_button=True,
                    value=""
                )
            with gr.TabItem("🎙️ Source Transcriptions & Logs"):
                transcriptions_output = gr.Textbox(
                    label="Transcriptions and Processing Log",
                    lines=15,
                    show_copy_button=True,
                    value=""
                )

        # --- Event Handlers ---
        outputs_list = [news_output, transcriptions_output]

        # Generate button click
        generate_button.click(
            fn=generate_news,
            inputs=all_inputs,  # Pass the consolidated input list
            outputs=outputs_list
        )

        # Clear button click
        def clear_all_inputs_and_outputs():
            # Return a list of default values matching the number and type of inputs + outputs
            reset_values = []
            for input_comp in all_inputs:
                # Defaults: "" for Textbox/Dropdown, 250 for the Slider, None for File
                if isinstance(input_comp, (gr.Textbox, gr.Dropdown)):
                    reset_values.append("")
                elif isinstance(input_comp, gr.Slider):
                    reset_values.append(250)  # Reset slider to its default
                elif isinstance(input_comp, gr.File):
                    reset_values.append(None)
                else:
                    reset_values.append(None)  # Default for unknown/other types

            # Add default values for the two Textbox outputs
            reset_values.extend(["", ""])

            # Also reset the models in the background
            model_manager.reset_models(force=True)
            logger.info("UI cleared and models reset.")

            return reset_values

        clear_button.click(
            fn=clear_all_inputs_and_outputs,
            inputs=None,  # The clear function takes no inputs
            outputs=all_inputs + outputs_list  # Components to clear
        )
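
        # Hedged sketch: recent Gradio releases also ship a built-in ClearButton that
        # could replace the manual reset of the UI components (the model reset would
        # still need its own callback); availability depends on the installed version:
        #
        #   clear_button = gr.ClearButton(
        #       components=all_inputs + outputs_list,
        #       value="🔄 Clear All Inputs"
        #   )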

        # Event handler to reset models when the Gradio app closes or reloads (if possible)
        # demo.unload(model_manager.reset_models, inputs=None, outputs=None)  # May not work reliably in Spaces

    return demo

if __name__ == "__main__":
    logger.info("Starting NewsIA application...")

    # Optional: pre-initialize Whisper on startup if resources allow.
    # This makes the first transcription faster but claims GPU resources immediately;
    # consider enabling it only if transcriptions are very common.
    # try:
    #     logger.info("Attempting to pre-initialize Whisper model...")
    #     model_manager.initialize_whisper()
    # except Exception as e:
    #     logger.warning(f"Pre-initialization of Whisper model failed (will load on demand): {str(e)}")

    # Create the Gradio demo
    news_demo = create_demo()

    # Configure the queue with default settings (the concurrency_count and max_size
    # arguments were removed; the defaults suit most Spaces environments)
    news_demo.queue()

    # Launch the Gradio app
    logger.info("Launching Gradio interface...")
    news_demo.launch(
        server_name="0.0.0.0",  # Necessary for Docker/Spaces
        server_port=7860,
        # share=True   # Usually handled automatically by Spaces
        # debug=True   # Enable for more detailed Gradio logs if needed
    )
    logger.info("NewsIA application finished.")