# TWASR / model.py
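"""Model utilities for the TWASR Space.

Lazily loads two zh-TW speech recognition models (a fine-tuned Whisper
pipeline and a Phi-4 multimodal model) and exposes GPU-decorated
transcription functions plus a preloading helper.
"""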
import spaces
from typing import Optional
import logging
import time
import threading
import torch
import librosa
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, Pipeline
from accelerate import Accelerator
# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Try to install Flash Attention 2; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids
# compiling CUDA kernels at install time. Fall back to standard attention if
# the install fails.
try:
    import subprocess

    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
    logger.info("Flash Attention installed successfully.")
    USE_FA = True
except Exception:
    USE_FA = False
    logger.warning("Flash Attention not available. Using standard attention instead.")
# Model constants
MODEL_ID = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
PHI_MODEL_ID = "JacobLinCool/Phi-4-multimodal-instruct-commonvoice-zh-tw"
# Model instances (initialized lazily)
pipe: Optional[Pipeline] = None
phi_model = None
phi_processor = None
# Lock for thread-safe model loading
model_loading_lock = threading.Lock()
def load_model() -> None:
    """
    Load the Whisper model for transcription.
    Uses GPU if available.
    """
    global pipe

    if pipe is not None:
        return  # Model already loaded

    with model_loading_lock:
        if pipe is not None:
            return  # Another thread finished loading while we waited

        try:
            start_time = time.time()
            logger.info(f"Loading Whisper model {MODEL_ID}...")
            device = Accelerator().device
            pipe = pipeline(
                "automatic-speech-recognition", model=MODEL_ID, device=device
            )
            logger.info(
                f"Model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise
def get_gpu_duration(audio: str) -> int:
    """
    Calculate required GPU allocation time based on audio duration.

    The audio length in seconds is rounded up to the next full minute, with a
    minimum allocation of 60 seconds (e.g. a 150-second clip requests 180
    seconds of GPU time; a 20-second clip requests the 60-second minimum).

    Args:
        audio: Path to audio file

    Returns:
        GPU allocation time in seconds
    """
    try:
        y, sr = librosa.load(audio)
        duration = librosa.get_duration(y=y, sr=sr)
        gpu_duration = max(60.0, ((duration + 59.0) // 60.0) * 60.0)
        logger.info(
            f"Audio duration: {duration:.2f} s, allocated GPU time: {gpu_duration:.0f} s"
        )
        return int(gpu_duration)
    except Exception as e:
        logger.error(f"Failed to calculate GPU duration: {str(e)}")
        return 60  # Default to 1 minute if calculation fails
@spaces.GPU(duration=get_gpu_duration)
def transcribe_audio_local(audio: str) -> str:
    """
    Transcribe audio using the Whisper model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Whisper: {audio}")
        if pipe is None:
            load_model()
        out = pipe(audio, return_timestamps=True)
        return out.get("text", "No transcription generated")
    except Exception as e:
        logger.error(f"Whisper transcription error: {str(e)}")
        raise
def load_phi_model() -> None:
    """
    Load the Phi-4 model and processor.
    Uses GPU with Flash Attention if available.
    """
    global phi_model, phi_processor

    if phi_model is not None and phi_processor is not None:
        return  # Model already loaded

    with model_loading_lock:
        if phi_model is not None and phi_processor is not None:
            return  # Another thread finished loading while we waited

        try:
            start_time = time.time()
            logger.info(f"Loading Phi-4 model {PHI_MODEL_ID}...")
            phi_processor = AutoProcessor.from_pretrained(
                PHI_MODEL_ID, trust_remote_code=True
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.bfloat16 if USE_FA else torch.float32
            attn_implementation = "flash_attention_2" if USE_FA else "sdpa"
            phi_model = AutoModelForCausalLM.from_pretrained(
                PHI_MODEL_ID,
                torch_dtype=dtype,
                _attn_implementation=attn_implementation,
                trust_remote_code=True,
            ).to(device)
            logger.info(
                f"Phi-4 model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Phi-4 model: {str(e)}")
            raise
@spaces.GPU(duration=get_gpu_duration)
def transcribe_audio_phi(audio: str) -> str:
    """
    Transcribe audio using the Phi-4 model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Phi-4: {audio}")
        load_phi_model()

        # Load and resample audio to 16kHz
        y, sr = librosa.load(audio, sr=16000)

        # Prepare the user message and generate the prompt
        user_message = {
            "role": "user",
            "content": "<|audio_1|> Transcribe the audio clip into text.",
        }
        prompt = phi_processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )

        # Build inputs for the model
        inputs = phi_processor(text=prompt, audios=[(y, sr)], return_tensors="pt")
        inputs = {
            k: v.to(phi_model.device) if hasattr(v, "to") else v
            for k, v in inputs.items()
        }

        # Generate transcription without gradients
        with torch.no_grad():
            generated_ids = phi_model.generate(
                **inputs,
                eos_token_id=phi_processor.tokenizer.eos_token_id,
                max_new_tokens=256,  # Increased for longer transcriptions
                do_sample=False,
            )

        # Decode only the newly generated tokens (after the prompt) into text
        transcription = phi_processor.decode(
            generated_ids[0, inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        logger.info("Phi-4 transcription completed successfully")
        return transcription
    except Exception as e:
        logger.error(f"Phi-4 transcription error: {str(e)}")
        raise
def preload_models() -> None:
    """
    Preload models into memory to reduce cold start time.
    This function can be called at application startup.
    """
    try:
        logger.info("Preloading models to reduce cold start time")
        # Load Whisper model first as it's the default
        load_model()
        # Then load Phi model
        load_phi_model()
        logger.info("All models preloaded successfully")
    except Exception as e:
        logger.error(f"Error during model preloading: {str(e)}")