# ============================
# models.py — Model Loader, Globals
# ============================
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
# from torch.serialization import add_safe_globals
from omegaconf import ListConfig
import logging
import asyncio
from together import Together
from pyngrok import conf
from config import HF_CACHE_DIR, TORCH_CACHE_DIR, token, together_api_key, ngrok_auth_token, PYTHAI_NLP
import os

PREFERRED_MODEL = os.environ.get("WHISPER_MODEL", "large-v3-turbo")
FALLBACK_MODELS = os.environ.get("WHISPER_MODEL_FALLBACK", "large-v2,large,medium,small,tiny").split(",")

# Declare global variables here
logger = logging.getLogger(__name__)
model_lock = asyncio.Lock()
pipelines = []
models = []
others = []
overlap_pipeline = None


def setup_together_and_ngrok():
    """Initialize the Together client and register the ngrok auth token."""
    together = Together(api_key=together_api_key)
    conf.get_default().auth_token = ngrok_auth_token
    # add_safe_globals({ListConfig})
    return together


together = setup_together_and_ngrok()


async def load_model_bundle():
    global pipelines, models  # , overlap_pipeline
    if pipelines and models:
        logger.info("✅ Models already loaded. Skipping reinitialization.")
        return pipelines[0], models[0]

    def _load_models():
        n = torch.cuda.device_count()
        logger.info(f"🖥️ Found {n} CUDA device(s)")
        if n == 0:
            device_str = "cpu"
            device_torch = torch.device(device_str)
        elif n == 1:
            device_str = "cuda"
            device_torch = torch.device(device_str)
        else:
            # Multiple GPUs: pin the diarization pipeline to the first one.
            device_str = "cuda"
            device_torch = torch.device("cuda:0")

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=token,
            cache_dir=HF_CACHE_DIR
        ).to(device_torch)

        # overlap_pipeline = Pipeline.from_pretrained(
        #     "pyannote/overlapped-speech-detection",
        #     use_auth_token=token,
        #     cache_dir=HF_CACHE_DIR  # use the same cache as the other models
        # )

        model_fallback_chain = [PREFERRED_MODEL] + [m for m in FALLBACK_MODELS if m != PREFERRED_MODEL]
        # float16 is only available on GPU; fall back to int8 on CPU so the
        # fallback chain does not fail on every model for a compute-type error.
        compute_type = "float16" if device_str == "cuda" else "int8"
        model = None
        for model_name in model_fallback_chain:
            try:
                logger.info(f"🔍 Trying to load Whisper model: {model_name}")
                model = WhisperModel(model_name, device=device_str, compute_type=compute_type)
                logger.info(f"✅ Loaded Whisper model: {model_name}")
                break
            except Exception as e:
                logger.warning(f"⚠️ Failed to load {model_name}: {e}")

        if model is None:
            raise RuntimeError("❌ Failed to load any Whisper model from fallback chain.")

        pipelines.append(pipeline)
        models.append(model)
        return pipeline, model

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _load_models)
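
# ----------------------------
# Usage sketch (illustrative, not part of the original module): how a caller
# might warm the models up once at startup. It assumes `model_lock` is meant to
# serialize the first load across concurrent callers; the `_warmup` coroutine
# name and the __main__ guard are assumptions added here for demonstration.
# ----------------------------
if __name__ == "__main__":
    async def _warmup():
        # Hold the lock so only one coroutine triggers the initial load.
        async with model_lock:
            pipeline, model = await load_model_bundle()
        # A second call returns the cached instances without reloading.
        await load_model_bundle()
        logger.info("Ready: %s / %s", type(pipeline).__name__, type(model).__name__)

    asyncio.run(_warmup())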