import asyncio
import logging
import os

import torch
from faster_whisper import WhisperModel
from omegaconf import ListConfig
from pyannote.audio import Pipeline
from pyngrok import conf
from together import Together

from config import (
    HF_CACHE_DIR,
    TORCH_CACHE_DIR,
    token,
    together_api_key,
    ngrok_auth_token,
    PYTHAI_NLP,
)

PREFERRED_MODEL = os.environ.get("WHISPER_MODEL", "large-v3-turbo")
FALLBACK_MODELS = os.environ.get("WHISPER_MODEL_FALLBACK", "large-v2,large,medium,small,tiny").split(",")
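# Deployment example: both values can be overridden per environment, e.g.
#   WHISPER_MODEL=medium WHISPER_MODEL_FALLBACK="small,tiny" python app.py
# (the entry-point name app.py is illustrative, not part of this module)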

logger = logging.getLogger(__name__)

# Module-level singletons: each list holds at most one loaded instance, so
# repeated calls to load_model_bundle() reuse the first successful load.
model_lock = asyncio.Lock()
pipelines = []
models = []
others = []
overlap_pipeline = None


def setup_together_and_ngrok():
    """Create the Together client and register the ngrok auth token."""
    together = Together(api_key=together_api_key)
    conf.get_default().auth_token = ngrok_auth_token
    return together


together = setup_together_and_ngrok()
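# Note: conf.get_default() mutates pyngrok's process-wide default config, so
# any later ngrok.connect(...) call in this process picks up this auth token.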


async def load_model_bundle():
    global pipelines, models

    # Serialize loading so concurrent callers cannot initialize models twice.
    async with model_lock:
        if pipelines and models:
            logger.info("✅ Models already loaded. Skipping reinitialization.")
            return pipelines[0], models[0]

        def _load_models():
            n = torch.cuda.device_count()
            logger.info(f"🖥️ Found {n} CUDA device(s)")
            device_str = "cpu" if n == 0 else "cuda"
            # With multiple GPUs, pin the diarization pipeline to the first one.
            device_torch = torch.device("cuda:0" if n > 1 else device_str)

            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token,
                cache_dir=HF_CACHE_DIR,
            ).to(device_torch)

            # Try the preferred model first, then the remaining fallbacks in order.
            model_fallback_chain = [PREFERRED_MODEL] + [m for m in FALLBACK_MODELS if m != PREFERRED_MODEL]

            # float16 is only supported on GPU; fall back to int8 on CPU.
            compute_type = "float16" if device_str == "cuda" else "int8"

            model = None
            for model_name in model_fallback_chain:
                try:
                    logger.info(f"🔍 Trying to load Whisper model: {model_name}")
                    model = WhisperModel(model_name, device=device_str, compute_type=compute_type)
                    logger.info(f"✅ Loaded Whisper model: {model_name}")
                    break
                except Exception as e:
                    logger.warning(f"⚠️ Failed to load {model_name}: {e}")

            if model is None:
                raise RuntimeError("❌ Failed to load any Whisper model from fallback chain.")

            pipelines.append(pipeline)
            models.append(model)
            return pipeline, model

        # Run the blocking load in a worker thread so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _load_models)
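

# A minimal usage sketch, run only when this module is executed directly.
# The audio path "sample.wav" and the helper name _demo are placeholders.
if __name__ == "__main__":
    async def _demo():
        diarization_pipeline, whisper_model = await load_model_bundle()

        # faster-whisper returns a lazy generator of segments plus metadata.
        segments, info = whisper_model.transcribe("sample.wav")
        print(f"Detected language: {info.language}")
        for segment in segments:
            print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")

        # pyannote returns an Annotation; itertracks yields (segment, track, label).
        diarization = diarization_pipeline("sample.wav")
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            print(f"{speaker}: {turn.start:.2f}s -> {turn.end:.2f}s")

    asyncio.run(_demo())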