import importlib
from types import SimpleNamespace

import gradio as gr
import pandas as pd

import spaces
import torch

from utmosv2.utils import get_dataset, get_model

# Markdown rendered at the top of the demo page. The three sections
# (title, GitHub badge, usage note) are separated by blank lines.
description = "\n\n".join(
    [
        "# 🚀 UTMOSv2 demo",
        "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)",
        "This is a demonstration of MOS prediction using UTMOSv2. "
        "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate.",
    ]
)

# All inference tensors and models are placed on the CUDA device.
device = torch.device("cuda")

# Import the fusion_stage3 config module and copy every attribute whose name
# does not start with a double underscore into a mutable namespace.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{name: value for name, value in vars(config).items() if not name.startswith("__")}
)

# Settings applied on top of the module's values for this demo.
vars(cfg).update(
    reproduce=False,
    config="fusion_stage3",
    print_config=False,
    data_config=None,
    phase="inference",
    num_workers=1,
)

@spaces.GPU
@torch.inference_mode()
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
    """Predict the MOS of a single audio file with the fusion_stage3 models.

    For each of the five fold checkpoints the model is loaded and run five
    times on the input; the returned score is the mean of all 25 predictions.
    When ``quick`` is true, the very first prediction (fold 0, first
    repetition) is returned immediately without averaging.

    Args:
        audio_path: Path to the audio file to score, as supplied by the
            Gradio ``Audio`` component.
        domain: Data-domain ID stored in the ``dataset`` column of the
            one-row frame handed to ``get_dataset``.
        quick: If true, return after a single prediction.

    Returns:
        The predicted MOS. NOTE(review): the value is accumulated from
        ``p.cpu().numpy()[0][0]`` and may therefore be a NumPy scalar rather
        than a builtin ``float``.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Gradio passes None when the Audio component is empty; fail early with a
    # readable message instead of an opaque error from the dataset/model code.
    if not audio_path:
        raise gr.Error("Please provide an audio file before submitting.")

    num_folds = 5
    num_repetitions = 5

    # One-row frame describing the input; "mos" is a placeholder value.
    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    data["mos"] = 0

    preds = 0.0
    for fold in range(num_folds):
        # get_model receives the shared cfg, so point it at this fold's checkpoint.
        cfg.now_fold = fold
        cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
        model = get_model(cfg, device).eval()
        for _ in range(num_repetitions):
            # The dataset is rebuilt for every repetition.
            # NOTE(review): presumably each access draws a different random
            # frame of the waveform — confirm before hoisting this call.
            test_dataset = get_dataset(cfg, data, "test")
            # Every item of the sample except the last is fed to the model
            # with a leading batch dimension of 1.
            # NOTE(review): the last item is presumably the target — confirm.
            inputs = [
                torch.tensor(t, dtype=torch.float32).unsqueeze(0).to(device)
                for t in test_dataset[0][:-1]
            ]
            p = model(*inputs)
            preds += p.cpu().numpy()[0][0]
            if quick:
                return preds
    # Average over every (fold, repetition) pair.
    preds /= float(num_folds * num_repetitions)
    return preds


# Build the demo page: inputs (audio, domain, quick toggle, submit button) in
# the left column, the predicted score in the right column.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # type="filepath" makes Gradio pass the audio to the callback as a path.
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                # Adjacent literals form ONE string; there must be no trailing
                # comma here, otherwise `info` becomes a one-element tuple.
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "To make quick predictions by reducing this to a single repetition, "
                    "check this checkbox:"
                ),
            )
            submit = gr.Button(value="Submit")

        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    # Run the prediction when the button is clicked and show it in the textbox.
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

# Enable request queuing and start the app.
demo.queue().launch()