import logging
import subprocess
import threading
import time
from typing import Optional

import librosa
import spaces
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoProcessor, Pipeline, pipeline

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

try:
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,  # Raise if the install fails so USE_FA is set correctly
    )
    logger.info("Flash Attention installed successfully.")
    USE_FA = True
except Exception:
    USE_FA = False
    logger.warning("Flash Attention not available. Using standard attention instead.")

# Model constants
MODEL_ID = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
PHI_MODEL_ID = "JacobLinCool/Phi-4-multimodal-instruct-commonvoice-zh-tw"

# Model instances (initialized lazily)
pipe: Optional[Pipeline] = None
phi_model = None
phi_processor = None

# Lock for thread-safe model loading
model_loading_lock = threading.Lock()


def load_model() -> None:
    """
    Load the Whisper model for transcription.
    Uses GPU if available.
    """
    global pipe

    with model_loading_lock:
        if pipe is not None:
            return  # Model already loaded

        try:
            start_time = time.time()
            logger.info(f"Loading Whisper model {MODEL_ID}...")
            device = Accelerator().device
            pipe = pipeline(
                "automatic-speech-recognition", model=MODEL_ID, device=device
            )
            logger.info(
                f"Model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise


def get_gpu_duration(audio: str) -> int:
    """
    Calculate required GPU allocation time based on audio duration.

    Args:
        audio: Path to audio file

    Returns:
        GPU allocation time in seconds
    """
    try:
        y, sr = librosa.load(audio)
        duration = librosa.get_duration(y=y, sr=sr)  # seconds
        # Round up to the next whole minute, with a 60-second minimum
        gpu_duration = max(60.0, ((duration + 59.0) // 60.0) * 60.0)
        logger.info(
            f"Audio duration: {duration / 60.0:.2f} min, "
            f"allocated GPU time: {gpu_duration:.0f} s"
        )
        return int(gpu_duration)
    except Exception as e:
        logger.error(f"Failed to calculate GPU duration: {str(e)}")
        return 60  # Default to 1 minute if calculation fails


@spaces.GPU(duration=get_gpu_duration)
def transcribe_audio_local(audio: str) -> str:
    """
    Transcribe audio using the Whisper model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Whisper: {audio}")
        if pipe is None:
            load_model()
        out = pipe(audio, return_timestamps=True)
        return out.get("text", "No transcription generated")
    except Exception as e:
        logger.error(f"Whisper transcription error: {str(e)}")
        raise


def load_phi_model() -> None:
    """
    Load the Phi-4 model and processor.
    Uses GPU with Flash Attention if available.
""" global phi_model, phi_processor if phi_model is not None and phi_processor is not None: return # Model already loaded try: start_time = time.time() logger.info(f"Loading Phi-4 model {PHI_MODEL_ID}...") phi_processor = AutoProcessor.from_pretrained( PHI_MODEL_ID, trust_remote_code=True ) device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if USE_FA else torch.float32 attn_implementation = "flash_attention_2" if USE_FA else "sdpa" phi_model = AutoModelForCausalLM.from_pretrained( PHI_MODEL_ID, torch_dtype=dtype, _attn_implementation=attn_implementation, trust_remote_code=True, ).to(device) logger.info( f"Phi-4 model loaded successfully in {time.time() - start_time:.2f} seconds" ) except Exception as e: logger.error(f"Failed to load Phi-4 model: {str(e)}") raise @spaces.GPU(duration=get_gpu_duration) def transcribe_audio_phi(audio: str) -> str: """ Transcribe audio using the Phi-4 model. Args: audio: Path to audio file Returns: Transcribed text """ try: logger.info(f"Transcribing audio with Phi-4: {audio}") load_phi_model() # Load and resample audio to 16kHz y, sr = librosa.load(audio, sr=16000) # Prepare the user message and generate the prompt user_message = { "role": "user", "content": "<|audio_1|> Transcribe the audio clip into text.", } prompt = phi_processor.tokenizer.apply_chat_template( [user_message], tokenize=False, add_generation_prompt=True ) # Build inputs for the model inputs = phi_processor(text=prompt, audios=[(y, sr)], return_tensors="pt") inputs = { k: v.to(phi_model.device) if hasattr(v, "to") else v for k, v in inputs.items() } # Generate transcription without gradients with torch.no_grad(): generated_ids = phi_model.generate( **inputs, eos_token_id=phi_processor.tokenizer.eos_token_id, max_new_tokens=256, # Increased for longer transcriptions do_sample=False, ) # Decode the generated token IDs into text transcription = phi_processor.decode( generated_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False, ) logger.info(f"Phi-4 transcription completed successfully") return transcription except Exception as e: logger.error(f"Phi-4 transcription error: {str(e)}") raise def preload_models() -> None: """ Preload models into memory to reduce cold start time. This function can be called at application startup. """ try: logger.info("Preloading models to reduce cold start time") # Load Whisper model first as it's the default load_model() # Then load Phi model load_phi_model() logger.info("All models preloaded successfully") except Exception as e: logger.error(f"Error during model preloading: {str(e)}")