import base64
import html
import logging
import math
import os
import tempfile
import time
from typing import Optional, Tuple
from urllib.parse import parse_qs, urlparse

import fastapi
import jax.numpy as jnp
import numpy as np
import yt_dlp as youtube_dl
from jax.experimental.compilation_cache import compilation_cache as cc
from pydantic import BaseModel
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read

from whisper_jax import FlaxWhisperPipline

# Enable JAX's compilation cache under ./jax_cache (presumably so compiled programs can be
# reused across process restarts -- TODO confirm for the JAX version in use).
cc.initialize_cache("./jax_cache")
# Model checkpoint handed to FlaxWhisperPipline below.
checkpoint = "openai/whisper-large-v3"

# Batch size used both when chunking audio (preprocess_batch) and for every forward call.
BATCH_SIZE = 32
# Length, in seconds, of the windows the audio is split into before inference.
CHUNK_LENGTH_S = 30
# NOTE(review): not referenced anywhere in this file -- confirm whether it is still needed.
NUM_PROC = 32
# Upper bound, in MiB, on the base64-decoded upload accepted by /transcribe (10000 MiB ~= 9.8 GiB).
FILE_LIMIT_MB = 10000
# Maximum YouTube video length in seconds: 15000 s = 4 h 10 min.
# NOTE(review): this was previously commented as a "2 hour" limit (7200 s); confirm which is intended.
YT_LENGTH_LIMIT_S = 15000

# Application logger: INFO and above is written to stderr as "timestamp;LEVEL;message".
formatter = logging.Formatter(
    fmt="%(asctime)s;%(levelname)s;%(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
logger.addHandler(ch)

# Build the Whisper inference pipeline with dtype bfloat16 and the shared batch size.
pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
# Chunking geometry, in samples. Each chunk of CHUNK_LENGTH_S seconds carries a stride of
# CHUNK_LENGTH_S / 6 seconds on each side, so consecutive chunk starts are `step` samples apart.
# _tqdm_generate uses `step` to count how many chunks (and so batches) an input produces.
# NOTE(review): this assumes pipeline.preprocess_batch uses the same stride by default -- confirm.
stride_length_s = CHUNK_LENGTH_S / 6
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
step = chunk_len - stride_left - stride_right

# do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
logger.info("compiling forward call...")
start = time.time()
# Dummy all-ones batch of shape (BATCH_SIZE, num_mel_bins, 2 * max_source_positions); presumably
# this matches the feature shape of real chunks so the compiled program is reused (TODO confirm).
random_inputs = {
    "input_features": np.ones(
        (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * pipeline.model.config.max_source_positions)
    )
}
# The output is never used; the call exists only for its warm-up/compilation side effect.
random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
compile_time = time.time() - start
logger.info(f"compiled in {compile_time}s")

# FastAPI application; the endpoints below are registered on it via @app.post decorators.
app = fastapi.FastAPI()

# JSON request body for POST /transcribe (consumed by transcribe_audio).
class TranscriptionRequest(BaseModel):
    # Base64-encoded bytes of the audio file; decoded with base64.b64decode, then ffmpeg_read.
    audio_file: str
    # Forwarded to pipeline.forward (presumably "transcribe" or "translate"; not validated here).
    task: str = "transcribe"
    # If True, the transcription is one "[start -> end] text" line per chunk instead of plain text.
    return_timestamps: bool = False

# JSON response body for POST /transcribe.
class TranscriptionResponse(BaseModel):
    # Plain transcription, or newline-separated "[start -> end] text" lines when timestamps were requested.
    transcription: str
    # Wall-clock seconds measured around the model forward loop only (post-processing excluded).
    runtime: float

@app.post("/transcribe", response_model=TranscriptionResponse)
def transcribe_audio(request: TranscriptionRequest):
    """Transcribe a base64-encoded audio file sent in the request body.

    Args:
        request: ``audio_file`` holds the base64-encoded audio bytes, which are decoded with
            ``ffmpeg_read`` at the pipeline's sampling rate; ``task`` and ``return_timestamps``
            are forwarded to ``_tqdm_generate``.

    Returns:
        TranscriptionResponse with the text and the forward-loop runtime in seconds.

    Raises:
        fastapi.HTTPException: 400 if ``audio_file`` is empty, is not valid base64, or decodes
            to more than ``FILE_LIMIT_MB`` MiB.
    """
    logger.info("loading audio file...")
    if not request.audio_file:
        logger.warning("No audio file")
        raise fastapi.HTTPException(status_code=400, detail="No audio file submitted!")

    try:
        audio_bytes = base64.b64decode(request.audio_file)
    except ValueError as err:
        # binascii.Error (bad padding) and non-ASCII input both subclass ValueError;
        # they are client mistakes, so report 400 rather than an unhandled 500.
        logger.warning("Invalid base64 audio payload")
        raise fastapi.HTTPException(status_code=400, detail="audio_file is not valid base64.") from err

    file_size_mb = len(audio_bytes) / (1024 * 1024)
    if file_size_mb > FILE_LIMIT_MB:
        logger.warning("Max file size exceeded")
        raise fastapi.HTTPException(
            status_code=400,
            detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.",
        )

    sampling_rate = pipeline.feature_extractor.sampling_rate
    waveform = ffmpeg_read(audio_bytes, sampling_rate)
    inputs = {"array": waveform, "sampling_rate": sampling_rate}
    logger.info("done loading")
    text, runtime = _tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
    return TranscriptionResponse(transcription=text, runtime=runtime)

@app.post("/transcribe_youtube")
def transcribe_youtube(
    yt_url: str, task: str = "transcribe", return_timestamps: bool = False
) -> Tuple[str, str, float]:
    """Download a YouTube video's audio and transcribe it.

    Args:
        yt_url: URL of the video to transcribe.
        task: Forwarded to ``_tqdm_generate``.
        return_timestamps: If True, the text is one ``"[start -> end] text"`` line per chunk.

    Returns:
        ``(html_embed, text, runtime)``: an ``<iframe>`` embed snippet for the video, the
        transcription, and the forward-loop runtime in seconds.

    Raises:
        fastapi.HTTPException: 400, propagated from ``_download_yt_audio``.
    """
    logger.info("loading youtube file...")
    html_embed_str = _return_yt_html_embed(yt_url)
    # The download lives in a throwaway directory; only its bytes survive the `with` block.
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")  # relies on the module-level `import os`
        _download_yt_audio(yt_url, filepath)

        with open(filepath, "rb") as f:
            audio_bytes = f.read()

    sampling_rate = pipeline.feature_extractor.sampling_rate
    waveform = ffmpeg_read(audio_bytes, sampling_rate)
    inputs = {"array": waveform, "sampling_rate": sampling_rate}
    logger.info("done loading...")
    text, runtime = _tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
    return html_embed_str, text, runtime

def _tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: Optional[object] = None):
    """Run batched Whisper inference over a decoded waveform and return ``(text, runtime)``.

    Args:
        inputs: ``{"array": waveform, "sampling_rate": rate}`` as built by the endpoints;
            ``inputs["array"].shape[0]`` is taken as the length in samples.
        task: Forwarded unchanged to ``pipeline.forward``.
        return_timestamps: If True, return one ``"[start -> end] text"`` line per chunk;
            otherwise return the plain transcription.
        progress: Accepted for backward compatibility and ignored. It is annotated as plain
            ``object`` because fastapi has no ``ProgressBar`` type to reference.

    Returns:
        ``(text, runtime)``, where ``runtime`` is the wall-clock seconds measured around the
        forward loop only (post-processing is excluded).
    """
    inputs_len = inputs["array"].shape[0]
    # Chunk starts are `step` samples apart, so this is how many chunks preprocess_batch yields.
    # NOTE(review): assumes its default stride matches the module-level `step` -- confirm.
    num_samples = len(np.arange(0, inputs_len, step))
    num_batches = math.ceil(num_samples / BATCH_SIZE)

    dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
    model_outputs = []
    start_time = time.time()
    logger.info("transcribing...")
    # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations;
    # zipping with range() caps the loop at num_batches batches.
    for batch, _ in zip(dataloader, range(num_batches)):
        model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
    runtime = time.time() - start_time
    logger.info("done transcription")

    logger.info("post-processing...")
    post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
    text = post_processed["text"]
    if return_timestamps:
        text = "\n".join(
            f"[{_format_timestamp(chunk['timestamp'][0])} -> {_format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in post_processed.get("chunks")
        )
    logger.info("done post-processing")
    return text, runtime

def _return_yt_html_embed(yt_url: str) -> str:
    video_id = yt_url.split("?v=")[-1]
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return HTML_str

def _download_yt_audio(yt_url: str, filename: str):
    """Check a YouTube video's length, then download it to ``filename`` with yt-dlp.

    Metadata is fetched first, without downloading, so over-long videos are rejected cheaply.

    Args:
        yt_url: Video URL, passed unchanged to yt-dlp.
        filename: Output path, used verbatim as yt-dlp's ``outtmpl``.

    Raises:
        fastapi.HTTPException: 400 in any of these cases:
            - yt-dlp cannot read or download the video.
            - The video's length cannot be determined.
            - It is longer than ``YT_LENGTH_LIMIT_S`` seconds.
    """
    with youtube_dl.YoutubeDL() as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise fastapi.HTTPException(status_code=400, detail=str(err)) from err

    file_length_s = _yt_length_seconds(info)
    if file_length_s is None:
        raise fastapi.HTTPException(status_code=400, detail="Could not determine the length of the YouTube video.")
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = _format_hms(YT_LENGTH_LIMIT_S)
        file_length_hms = _format_hms(file_length_s)
        raise fastapi.HTTPException(
            status_code=400,
            detail=f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.",
        )

    # Format selection, in order of preference:
    # - the smallest mp4 video stream plus the best m4a audio;
    # - else a single mp4;
    # - else whatever is best.
    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # download() typically surfaces failures as DownloadError, so catching only
            # ExtractorError let most failures escape as a 500.
            raise fastapi.HTTPException(status_code=400, detail=str(err)) from err


def _yt_length_seconds(info: dict) -> Optional[int]:
    """Return the video length in whole seconds from yt-dlp metadata, or None if unavailable."""
    duration_string = info.get("duration_string")
    if duration_string:
        # Colon-separated, most significant field first ("45", "2:03", "1:02:03"): fold base 60.
        total_s = 0
        for part in duration_string.split(":"):
            total_s = total_s * 60 + int(part)
        return total_s
    duration = info.get("duration")
    return None if duration is None else int(duration)


def _format_hms(total_s: int) -> str:
    """Format seconds as e.g. "04H:10M:00S"; unlike time.gmtime, hours do not wrap at 24."""
    hours, rem = divmod(int(total_s), 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}H:{minutes:02d}M:{secs:02d}S"

def _format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    if seconds is not None:
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    else:
        # we have a malformed timestamp so just return it as is
        return seconds