import spaces
from typing import Optional
import logging
import time
import threading
import torch
import librosa
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, Pipeline
from accelerate import Accelerator

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Try to install Flash Attention for faster inference; fall back to standard attention on failure
try:
    import subprocess

    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
    logger.info("Flash Attention installed successfully.")
    USE_FA = True
except Exception:
    USE_FA = False
    logger.warning("Flash Attention not available. Using standard attention instead.")

# Model constants
MODEL_ID = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
PHI_MODEL_ID = "JacobLinCool/Phi-4-multimodal-instruct-commonvoice-zh-tw"

# Model instances (initialized lazily)
pipe: Optional[Pipeline] = None
phi_model = None
phi_processor = None

# Lock for thread-safe model loading
model_loading_lock = threading.Lock()


def load_model() -> None:
    """
    Load the Whisper model for transcription.
    Uses GPU if available.
    """
    global pipe
    if pipe is not None:
        return  # Model already loaded
    with model_loading_lock:
        if pipe is not None:
            return  # Another thread loaded the model while we waited
        try:
            start_time = time.time()
            logger.info(f"Loading Whisper model {MODEL_ID}...")
            device = Accelerator().device
            pipe = pipeline("automatic-speech-recognition", model=MODEL_ID, device=device)
            logger.info(
                f"Model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise


def get_gpu_duration(audio: str) -> int:
    """
    Calculate required GPU allocation time based on audio duration.

    Args:
        audio: Path to audio file

    Returns:
        GPU allocation time in seconds
    """
    try:
        y, sr = librosa.load(audio)
        duration = librosa.get_duration(y=y, sr=sr)
        # Round up to the next full minute, with a minimum of 60 seconds
        gpu_duration = max(60.0, ((duration + 59.0) // 60.0) * 60.0)
        logger.info(
            f"Audio duration: {duration:.2f} s, allocated GPU time: {gpu_duration:.0f} s"
        )
        return int(gpu_duration)
    except Exception as e:
        logger.error(f"Failed to calculate GPU duration: {str(e)}")
        return 60  # Default to 1 minute if calculation fails


def transcribe_audio_local(audio: str) -> str:
    """
    Transcribe audio using the Whisper model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Whisper: {audio}")
        if pipe is None:
            load_model()
        out = pipe(audio, return_timestamps=True)
        return out.get("text", "No transcription generated")
    except Exception as e:
        logger.error(f"Whisper transcription error: {str(e)}")
        raise
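

# Hedged sketch (assumption, not part of the original module): on a ZeroGPU Space the
# GPU-bound entry point is usually wrapped with the ZeroGPU decorator, and
# get_gpu_duration can be passed as a dynamic duration so the allocation scales with
# the length of the input clip. The wrapper name `gpu_transcribe` is illustrative only.
@spaces.GPU(duration=get_gpu_duration)
def gpu_transcribe(audio: str) -> str:
    """Run Whisper transcription inside a dynamically sized GPU allocation."""
    return transcribe_audio_local(audio)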


def load_phi_model() -> None:
    """
    Load the Phi-4 model and processor.
    Uses GPU with Flash Attention if available.
    """
    global phi_model, phi_processor
    if phi_model is not None and phi_processor is not None:
        return  # Model already loaded
    with model_loading_lock:
        if phi_model is not None and phi_processor is not None:
            return  # Another thread loaded the model while we waited
        try:
            start_time = time.time()
            logger.info(f"Loading Phi-4 model {PHI_MODEL_ID}...")
            phi_processor = AutoProcessor.from_pretrained(
                PHI_MODEL_ID, trust_remote_code=True
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.bfloat16 if USE_FA else torch.float32
            attn_implementation = "flash_attention_2" if USE_FA else "sdpa"
            phi_model = AutoModelForCausalLM.from_pretrained(
                PHI_MODEL_ID,
                torch_dtype=dtype,
                _attn_implementation=attn_implementation,
                trust_remote_code=True,
            ).to(device)
            logger.info(
                f"Phi-4 model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Phi-4 model: {str(e)}")
            raise


def transcribe_audio_phi(audio: str) -> str:
    """
    Transcribe audio using the Phi-4 model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Phi-4: {audio}")
        load_phi_model()

        # Load and resample audio to 16 kHz
        y, sr = librosa.load(audio, sr=16000)

        # Prepare the user message and generate the prompt
        user_message = {
            "role": "user",
            "content": "<|audio_1|> Transcribe the audio clip into text.",
        }
        prompt = phi_processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )

        # Build inputs for the model
        inputs = phi_processor(text=prompt, audios=[(y, sr)], return_tensors="pt")
        inputs = {
            k: v.to(phi_model.device) if hasattr(v, "to") else v
            for k, v in inputs.items()
        }

        # Generate transcription without gradients
        with torch.no_grad():
            generated_ids = phi_model.generate(
                **inputs,
                eos_token_id=phi_processor.tokenizer.eos_token_id,
                max_new_tokens=256,  # Increased for longer transcriptions
                do_sample=False,
            )

        # Decode the generated token IDs into text, skipping the prompt tokens
        transcription = phi_processor.decode(
            generated_ids[0, inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        logger.info("Phi-4 transcription completed successfully")
        return transcription
    except Exception as e:
        logger.error(f"Phi-4 transcription error: {str(e)}")
        raise
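

# Hedged sketch (assumption, not part of the original module): a front end such as a
# Gradio app would typically choose between the two back ends by name. The function
# `transcribe_with` and the "whisper"/"phi-4" labels are illustrative only.
def transcribe_with(audio: str, backend: str = "whisper") -> str:
    """Dispatch to the selected transcription back end."""
    if backend == "phi-4":
        return transcribe_audio_phi(audio)
    return transcribe_audio_local(audio)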


def preload_models() -> None:
    """
    Preload models into memory to reduce cold start time.
    This function can be called at application startup.
    """
    try:
        logger.info("Preloading models to reduce cold start time")
        # Load Whisper model first as it's the default
        load_model()
        # Then load Phi model
        load_phi_model()
        logger.info("All models preloaded successfully")
    except Exception as e:
        logger.error(f"Error during model preloading: {str(e)}")