# Standard library imports
import os
import time
import tempfile
import logging
from timeit import default_timer as timer
from contextlib import asynccontextmanager

# Third-party imports
import numpy as np
import scipy.io.wavfile
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.background import BackgroundTask
from transformers import pipeline, VitsModel, VitsTokenizer

# External service imports
from google import genai
from google.genai import types
from silero_vad import (
    load_silero_vad,
    read_audio,
    get_speech_timestamps,
    save_audio,
    collect_chunks,
)

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@asynccontextmanager  # FastAPI lifespan handlers must be async context managers
async def lifespan(app: FastAPI):
    # Load models once at startup and store in app.state
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model_id = "rbcurzon/whisper-large-v3-turbo"
    app.state.pipe = pipeline(
        "automatic-speech-recognition",
        model=model_id,
        chunk_length_s=30,
        device=device
    )
    app.state.vad_model = load_silero_vad()
    app.state.client = genai.Client(api_key=os.environ.get("GENAI_API_KEY"))
    yield
    # Optionally, add cleanup code here
# FastAPI app initialization
app = FastAPI(
    title="Real-Time Audio Processor",
    description="Process and transcribe audio in real-time using Whisper",
    lifespan=lifespan
)
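
# CORSMiddleware is imported above but never registered; a minimal permissive
# setup is sketched here (an assumption; tighten allow_origins for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)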
def remove_silence(filename):
    """Remove silence from an audio file using Silero VAD."""
    sampling_rate = 16000
    try:
        wav = read_audio(filename, sampling_rate=sampling_rate)
        speech_timestamps = get_speech_timestamps(wav, app.state.vad_model, sampling_rate=sampling_rate)
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
        save_audio(
            temp_file,
            collect_chunks(speech_timestamps, wav),
            sampling_rate=sampling_rate
        )
        return temp_file
    except Exception as error:
        logging.error(f"Error removing silence from {filename}: {error}")
        raise HTTPException(status_code=500, detail=str(error))
def translate(text, srcLang, tgtLang):
    """Translate text from srcLang to tgtLang using the Gemini API."""
    prompt = f"Translate the following text: '{text}'"
    response = app.state.client.models.generate_content(
        model="gemini-2.5-flash-lite",
        contents=prompt,
        config=types.GenerateContentConfig(
            system_instruction=(
                f"You are an expert translator. Your task is to translate from {srcLang} to {tgtLang}. "
                "You must provide ONLY the translated text. Do not include any explanations, "
                "additional commentary, or conversational language. Just the translated text."
            ),
            thinking_config=types.ThinkingConfig(thinking_budget=0),  # Disables thinking
        )
    )
    return response.text
def remove_file(file):
    """Remove a file after a delay (runs as a background task after the response is sent)."""
    time.sleep(600)  # delay for 10 minutes
    os.remove(file)
# API Endpoints
@app.get("/")  # route path assumed
def read_root():
    return {
        "detail": "Philippine Regional Language Translator"
    }
@app.post("/translate-audio/")  # route path assumed
async def translate_audio(
    file: UploadFile = File(...),
    srcLang: str = Form("Tagalog"),
    tgtLang: str = Form("Cebuano")
):
    start = timer()
    temp_file = None  # initialize temp_file to None
    try:
        content = await file.read()
        with open(file.filename, 'wb') as f:
            f.write(content)
        logging.info(f"Successfully uploaded {file.filename}")
        generate_kwargs = {
            # Whisper's decoder context is 448 tokens; reserve 4 for the
            # forced decoder tokens prepended during generation
            "max_new_tokens": 448 - 4,
            "num_beams": 1,
            "condition_on_prev_tokens": False,
            "compression_ratio_threshold": 1.35,
            "temperature": 0.0,
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.6,
            "return_timestamps": True,
        }
        temp_file = remove_silence(file.filename)
        result = app.state.pipe(
            temp_file,
            batch_size=2,
            generate_kwargs=generate_kwargs
        )
        result_dict = {
            "transcribed_text": result['text'],
            "translated_text": translate(result['text'], srcLang=srcLang, tgtLang=tgtLang),
            "srcLang": srcLang,
            "tgtLang": tgtLang
        }
        return result_dict
    except Exception as error:
        logging.error(f"Error translating audio {file.filename}: {error}")
        raise HTTPException(status_code=500, detail=str(error))
    finally:
        if file.file:
            file.file.close()
        if os.path.exists(file.filename):
            os.remove(file.filename)
        if temp_file is not None and os.path.exists(temp_file):
            os.remove(temp_file)
        end = timer()
        logging.info(f"Translation completed for audio {file.filename} in {end - start:.2f} seconds")
@app.post("/translate-text/")  # route path assumed
async def translate_text(
    text: str = Form(...),  # Form(...) assumed, for consistency with the other fields
    srcLang: str = Form(...),
    tgtLang: str = Form(...)
):
    start = timer()
    result = translate(text, srcLang, tgtLang)
    if not result:
        logging.error("Translation failed for text: %s", text)
        raise HTTPException(status_code=500, detail="Translation failed")
    result_dict = {
        "text": text,
        "translated_text": result,
        "srcLang": srcLang,
        "tgtLang": tgtLang
    }
    end = timer()
    logging.info(f"Translation completed for text: {text} in {end - start:.2f} seconds")
    return result_dict
@app.post("/synthesize/")  # route path assumed
async def synthesize(text: str = Form(...)):
    start = timer()
    # Note: the TTS model is reloaded on every request; it could instead be
    # loaded once in lifespan() like the other models.
    model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
    tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-tgl")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    with torch.no_grad():
        outputs = model(input_ids)
    speech = outputs["waveform"]
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
    torchaudio.save(temp_file, speech.cpu(), 16000)  # MMS-TTS models output 16 kHz audio
    end = timer()
    logging.info(f"Synthesis completed for text: {text} in {end - start:.2f} seconds")
    return FileResponse(
        temp_file,
        media_type="audio/wav",
        filename="speech.wav",
        background=BackgroundTask(remove_file, temp_file)
    )
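
# Minimal local entry point, assuming uvicorn is installed; port 7860 is the
# Hugging Face Spaces convention (both assumptions, adjust as needed).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)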