import spaces
import logging
import math
import os
import subprocess
import threading
import time
from typing import Optional

import torch
import librosa
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, Pipeline
from accelerate import Accelerator

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

try:
    # flash-attn ships CUDA kernels; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes
    # pip use the prebuilt wheel instead of compiling from source.
    # Merging os.environ is required: passing a bare env dict would strip PATH
    # and pip would not be found.
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,  # raise on a non-zero exit so the except branch actually runs
    )
    logger.info("Flash Attention installed successfully.")
    USE_FA = True
except Exception:
    USE_FA = False
    logger.warning("Flash Attention not available. Using standard attention instead.")

# Model constants
MODEL_ID = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
PHI_MODEL_ID = "JacobLinCool/Phi-4-multimodal-instruct-commonvoice-zh-tw"

# Model instances (initialized lazily)
pipe: Optional[Pipeline] = None
phi_model = None
phi_processor = None

# Lock for thread-safe model loading
model_loading_lock = threading.Lock()
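# Both loaders below use double-checked locking with this lock: a fast unlocked
# check first, then a re-check after acquiring the lock, so concurrent requests
# never trigger two loads of the same model.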


def load_model() -> None:
    """
    Load the Whisper model for transcription.
    Uses GPU if available.
    """
    global pipe
    if pipe is not None:
        return  # Model already loaded

    with model_loading_lock:
        if pipe is not None:
            return  # Another thread finished loading while we waited

        try:
            start_time = time.time()
            logger.info(f"Loading Whisper model {MODEL_ID}...")
            device = Accelerator().device
            pipe = pipeline(
                "automatic-speech-recognition", model=MODEL_ID, device=device
            )
            logger.info(
                f"Model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise


def get_gpu_duration(audio: str) -> int:
    """
    Calculate required GPU allocation time based on audio duration.

    Args:
        audio: Path to audio file

    Returns:
        GPU allocation time in seconds
    """
    try:
        y, sr = librosa.load(audio)
        duration = librosa.get_duration(y=y, sr=sr)  # audio length in seconds
        # Round the allocation up to whole minutes, with a one-minute floor.
        gpu_duration = max(60.0, math.ceil(duration / 60.0) * 60.0)
        logger.info(
            f"Audio duration: {duration:.2f} s, allocated GPU time: {gpu_duration:.0f} s"
        )
        return int(gpu_duration)
    except Exception as e:
        logger.error(f"Failed to calculate GPU duration: {str(e)}")
        return 60  # Default to 1 minute if calculation fails
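
# A worked example of the rounding above: a 150 s clip yields
# ceil(150 / 60) * 60 = 180 s of GPU time, while any clip of 60 s or less
# gets the one-minute floor.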


@spaces.GPU(duration=get_gpu_duration)
def transcribe_audio_local(audio: str) -> str:
    """
    Transcribe audio using the Whisper model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Whisper: {audio}")
        load_model()  # idempotent; returns immediately if the model is loaded

        out = pipe(audio, return_timestamps=True)
        return out.get("text", "No transcription generated")
    except Exception as e:
        logger.error(f"Whisper transcription error: {str(e)}")
        raise


def load_phi_model() -> None:
    """
    Load the Phi-4 model and processor.
    Uses GPU with Flash Attention if available.
    """
    global phi_model, phi_processor
    if phi_model is not None and phi_processor is not None:
        return  # Model already loaded

    with model_loading_lock:
        if phi_model is not None and phi_processor is not None:
            return  # Another thread finished loading while we waited

        try:
            start_time = time.time()
            logger.info(f"Loading Phi-4 model {PHI_MODEL_ID}...")

            phi_processor = AutoProcessor.from_pretrained(
                PHI_MODEL_ID, trust_remote_code=True
            )

            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Flash Attention kernels require half precision; without them,
            # fall back to float32 and PyTorch's built-in SDPA attention.
            dtype = torch.bfloat16 if USE_FA else torch.float32
            attn_implementation = "flash_attention_2" if USE_FA else "sdpa"

            phi_model = AutoModelForCausalLM.from_pretrained(
                PHI_MODEL_ID,
                torch_dtype=dtype,
                _attn_implementation=attn_implementation,
                trust_remote_code=True,
            ).to(device)

            logger.info(
                f"Phi-4 model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Phi-4 model: {str(e)}")
            raise


@spaces.GPU(duration=get_gpu_duration)
def transcribe_audio_phi(audio: str) -> str:
    """
    Transcribe audio using the Phi-4 model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Phi-4: {audio}")
        load_phi_model()

        # Load and resample audio to 16kHz
        y, sr = librosa.load(audio, sr=16000)

        # Prepare the user message; <|audio_1|> is the placeholder the processor
        # replaces with the embedded audio clip.
        user_message = {
            "role": "user",
            "content": "<|audio_1|> Transcribe the audio clip into text.",
        }
        prompt = phi_processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )

        # Build inputs for the model
        inputs = phi_processor(text=prompt, audios=[(y, sr)], return_tensors="pt")
        inputs = {
            k: v.to(phi_model.device) if hasattr(v, "to") else v
            for k, v in inputs.items()
        }

        # Generate transcription without gradients
        with torch.no_grad():
            generated_ids = phi_model.generate(
                **inputs,
                eos_token_id=phi_processor.tokenizer.eos_token_id,
                max_new_tokens=256,  # cap on output length; generous enough for longer clips
                do_sample=False,
            )

        # Decode only the newly generated tokens, slicing off the prompt portion
        transcription = phi_processor.decode(
            generated_ids[0, inputs["input_ids"].shape[1] :],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        logger.info(f"Phi-4 transcription completed successfully")
        return transcription
    except Exception as e:
        logger.error(f"Phi-4 transcription error: {str(e)}")
        raise


def preload_models() -> None:
    """
    Preload models into memory to reduce cold start time.
    This function can be called at application startup.
    """
    try:
        logger.info("Preloading models to reduce cold start time")
        # Load Whisper model first as it's the default
        load_model()
        # Then load Phi model
        load_phi_model()
        logger.info("All models preloaded successfully")
    except Exception as e:
        # Preloading is best-effort: failures are only logged, and each model
        # is still loaded lazily on first use if needed.
        logger.error(f"Error during model preloading: {str(e)}")