# ============================
# models.py — Model Loader, Globals
# ============================
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
# from torch.serialization import add_safe_globals
from omegaconf import ListConfig
import logging
import asyncio
from together import Together
from pyngrok import conf
from config import HF_CACHE_DIR, TORCH_CACHE_DIR, token, together_api_key, ngrok_auth_token, PYTHAI_NLP
import os
PREFERRED_MODEL = os.environ.get("WHISPER_MODEL", "large-v3-turbo")
FALLBACK_MODELS = os.environ.get("WHISPER_MODEL_FALLBACK", "large-v2,large,medium,small,tiny").split(",")
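# Example (hypothetical command line): choose a smaller model without code changes,
# e.g. WHISPER_MODEL=medium WHISPER_MODEL_FALLBACK=small,tiny python app.py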
# Declare module-level global state here
logger = logging.getLogger(__name__)
model_lock = asyncio.Lock()
pipelines = []
models = []
others = []
overlap_pipeline = None
def setup_together_and_ngrok():
together = Together(api_key=together_api_key)
conf.get_default().auth_token = ngrok_auth_token
# add_safe_globals({ListConfig})
return together
together = setup_together_and_ngrok()
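# Note: setup_together_and_ngrok() runs at import time, so together_api_key and
# ngrok_auth_token must already be set before this module is imported.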
async def load_model_bundle():
    global pipelines, models
    # , overlap_pipeline
    async with model_lock:  # serialize first-time loading across concurrent callers
        if pipelines and models:
            logger.info("✅ Models already loaded. Skipping reinitialization.")
            return pipelines[0], models[0]

        def _load_models():
            n = torch.cuda.device_count()
            logger.info(f"🖥️ Found {n} CUDA device(s)")
            if n == 0:
                device_str = "cpu"
                device_torch = torch.device("cpu")
            else:
                # With one or more GPUs, pin everything to the first device.
                device_str = "cuda"
                device_torch = torch.device("cuda:0")
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token,
                cache_dir=HF_CACHE_DIR
            ).to(device_torch)
            # overlap_pipeline = Pipeline.from_pretrained(
            #     "pyannote/overlapped-speech-detection",
            #     use_auth_token=token,
            #     cache_dir=HF_CACHE_DIR  # same cache dir as the other models
            # )
            model_fallback_chain = [PREFERRED_MODEL] + [m for m in FALLBACK_MODELS if m != PREFERRED_MODEL]
            model = None
            # float16 requires a GPU; fall back to int8 quantization on CPU.
            compute_type = "float16" if device_str == "cuda" else "int8"
            for model_name in model_fallback_chain:
                try:
                    logger.info(f"🔍 Trying to load Whisper model: {model_name}")
                    model = WhisperModel(model_name, device=device_str, compute_type=compute_type)
                    logger.info(f"✅ Loaded Whisper model: {model_name}")
                    break
                except Exception as e:
                    logger.warning(f"⚠️ Failed to load {model_name}: {e}")
            if model is None:
                raise RuntimeError("❌ Failed to load any Whisper model from fallback chain.")
            pipelines.append(pipeline)
            models.append(model)
            return pipeline, model

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _load_models)
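# Minimal manual smoke test (a sketch, not the app's entry point). It assumes
# config.py holds valid credentials; "main" is only an illustrative name here.
# A serving app would instead await load_model_bundle() in its async startup hook.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def main():
        pipeline, model = await load_model_bundle()
        logger.info(f"Ready: pipeline={type(pipeline).__name__}, model={type(model).__name__}")

    asyncio.run(main())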