import glob
import os
import shutil
import sys
import re
import tempfile
import zipfile
from pathlib import Path

import gradio as gr

from finetune import finetune_model, log

from language import languages
from task import tasks
import matplotlib.pyplot as plt


def load_markdown():
    """Return the contents of ``intro.md`` from the current working directory.

    The file is decoded as UTF-8 so that non-ASCII markdown renders the same
    regardless of the server's locale settings.
    """
    with open("intro.md", "r", encoding="utf-8") as f:
        return f.read()


def read_logs(temp_dir):
    """Return the text of the session log ``<temp_dir>/output.log``.

    The UI polls this function to refresh the log textbox.

    Args:
        temp_dir: Session working directory that may contain ``output.log``.

    Returns:
        The log contents. Returns ``"Log file not found."`` when the file does
        not exist. Returns ``None`` when it exists but cannot be opened or read,
        for example when it was removed after the existence check or is a
        directory.
    """
    log_path = os.path.join(temp_dir, "output.log")
    if not os.path.exists(log_path):
        return "Log file not found."
    try:
        # errors="replace": the log may be read while another writer is in the
        # middle of a multibyte character; show the text rather than failing.
        with open(log_path, "r", errors="replace") as f:
            return f.read()
    except OSError:
        # Best effort: the next poll will try again.
        return None


def plot_loss_acc(temp_dir, log_every):
    """Plot loss and accuracy curves parsed from ``<temp_dir>/output.log``.

    Only lines of the exact form ``[<int>] - loss: <float> - acc: <float>``
    are used. The k-th matching line (1-based) is placed at
    ``x = k * log_every``.

    Args:
        temp_dir: Session working directory containing ``output.log``. The
            images are written there as ``acc.png`` and ``loss.png``.
        log_every: Logging interval used to space the points on the x axis.

    Returns:
        ``(acc_png_path, loss_png_path)``, or ``(None, None)`` when the log
        file is missing or contains no matching lines.
    """
    # NOTE(review): presumably flushes stdout so output redirected into the
    # log reaches disk before it is read — confirm against finetune.
    sys.stdout.flush()
    log_path = os.path.join(temp_dir, "output.log")
    if not os.path.exists(log_path):
        return None, None

    pattern = re.compile(r"^\[\d+\] - loss: (\d+\.\d+) - acc: (\d+\.\d+)$")
    losses = []
    accs = []
    with open(log_path, "r") as f:
        for line in f:
            match = pattern.match(line)
            if match:
                losses.append(float(match.group(1)))
                accs.append(float(match.group(2)))

    if not losses:
        return None, None

    x = [i * log_every for i in range(1, len(losses) + 1)]
    margin = log_every // 2
    acc_path = os.path.join(temp_dir, "acc.png")
    loss_path = os.path.join(temp_dir, "loss.png")

    # One explicit figure per plot instead of pyplot's shared "current
    # figure": this function is polled per session, so concurrent calls must
    # not draw into each other's plots. The figure is always closed, even if
    # saving fails.
    for values, label, path in ((losses, "loss", loss_path), (accs, "acc", acc_path)):
        fig, ax = plt.subplots()
        try:
            ax.plot(x, values, label=label)
            ax.set_xlim(margin, x[-1] + margin)
            fig.savefig(path)
        finally:
            plt.close(fig)
    return acc_path, loss_path


def _upload_error(temp_dir, message):
    """Log ``message`` to the session log and return a ``gr.Error`` for it.

    The error is returned rather than raised, so each call site keeps an
    explicit ``raise``.
    """
    log(temp_dir, message)
    return gr.Error(message)


def _read_first_column(path):
    """Return the first whitespace-separated token of each non-blank line.

    Decoded as UTF-8 so non-ASCII transcripts read the same on any locale.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.split(maxsplit=1)[0] for line in f if line.strip()]


def upload_file(fileobj, temp_dir):
    """
    Upload a file and check the uploaded zip file.

    The archive is extracted into ``temp_dir`` and must contain ``text``,
    ``text_ctc`` and ``audio``. The first column of ``text`` and ``text_ctc``
    is the audio id. Both files must list the same ids, and every file under
    ``audio/`` must have one of those ids as its stem.

    Args:
        fileobj: The uploaded file, either an object exposing its path as
            ``.name`` or a plain path string.
        temp_dir: Session working directory to extract into and log to.

    Returns:
        ``[fileobj]``, shown in the ``gr.File`` output component.

    Raises:
        gr.Error: If the upload is not a zip file or fails validation.
    """
    path = getattr(fileobj, "name", fileobj)

    # First check if a file is a zip file.
    if not zipfile.is_zipfile(path):
        raise _upload_error(temp_dir, "Please upload a zip file.")

    # Then unzip. The format is passed explicitly because the content is
    # already known to be a zip. Otherwise shutil infers the format from the
    # file name's (case-sensitive) extension and rejects names like "x.ZIP".
    log(temp_dir, "Unzipping file...")
    shutil.unpack_archive(path, temp_dir, format="zip")

    # Check that the required entries are present.
    for entry in ("text", "text_ctc", "audio"):
        if not os.path.exists(os.path.join(temp_dir, entry)):
            raise _upload_error(temp_dir, "Please upload a valid zip file.")

    # Check that all texts and audio match.
    log(temp_dir, "Checking if all texts and audio matches...")
    audio_ids = _read_first_column(os.path.join(temp_dir, "text"))
    ctc_audio_ids = _read_first_column(os.path.join(temp_dir, "text_ctc"))

    if len(audio_ids) != len(ctc_audio_ids):
        raise _upload_error(
            temp_dir,
            f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different.",
        )

    known_ids = set(audio_ids)
    if known_ids != set(ctc_audio_ids):
        raise _upload_error(temp_dir, "`text` and `text_ctc` have different audio ids.")

    for audio_path in glob.glob(os.path.join(temp_dir, "audio", "*")):
        audio_id = Path(audio_path).stem
        if audio_id not in known_ids:
            # Report the id itself, not the server-side extraction path.
            raise _upload_error(
                temp_dir, f"Audio id {audio_id} is not in `text` or `text_ctc`."
            )

    log(temp_dir, "Successfully uploaded and validated zip file.")
    gr.Info("Successfully uploaded and validated zip file.")

    return [fileobj]


def delete_tmp_dir(tmp_dir):
    """Recursively remove ``tmp_dir`` if it still exists, reporting the outcome.

    Registered as the ``delete_callback`` of the per-session ``gr.State``.
    """
    if not os.path.exists(tmp_dir):
        print("Temporary directory already deleted")
        return
    shutil.rmtree(tmp_dir)
    print(f"Deleted temporary directory: {tmp_dir}")


def create_tmp_dir():
    """Make a fresh, uniquely named temporary directory and return its path."""
    session_dir = tempfile.mkdtemp()
    message = f"Created temporary directory: {session_dir}"
    print(message)
    return session_dir


with gr.Blocks(title="OWSM-finetune") as demo:
    # Per-session working directory. create_tmp_dir supplies the initial
    # value and delete_tmp_dir removes the directory when the state is
    # dropped.
    # NOTE(review): time_to_live=600 (seconds per the Gradio docs) may be
    # shorter than a finetuning run — confirm the directory cannot be
    # deleted mid-training.
    tempdir_path = gr.State(create_tmp_dir, delete_callback=delete_tmp_dir, time_to_live=600)
    gr.Markdown(
        """# OWSM finetune demo!
Finetune `owsm_v3.1_ebf_base` with your own dataset!
Due to resource limitation, you can only train 5 epochs on maximum.
## Upload dataset and define settings
"""
    )

    # main contents: dataset upload (validated by upload_file) and
    # language/task selection.
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            upload_button = gr.UploadButton("Click to Upload a File", file_count="single")
            upload_button.upload(
                upload_file, [upload_button, tempdir_path], [file_output]
            )

        with gr.Column():
            lang = gr.Dropdown(
                languages["espnet/owsm_v3.1_ebf_base"],
                label="Language",
                info="Choose language!",
                value="jpn",
                interactive=True,
            )
            task = gr.Dropdown(
                tasks["espnet/owsm_v3.1_ebf_base"],
                label="Task",
                info="Choose task!",
                value="asr",
                interactive=True,
            )

    gr.Markdown("## Set training settings")

    # Hyperparameters, passed positionally to finetune_model below; each
    # control's label matches its variable name.
    with gr.Row():
        with gr.Column():
            log_every = gr.Number(value=10, label="log_every", interactive=True)
            max_epoch = gr.Slider(1, 5, step=1, label="max_epoch", interactive=True)
            scheduler = gr.Dropdown(
                ["warmuplr"], label="scheduler", value="warmuplr", interactive=True
            )
            warmup_steps = gr.Number(
                value=100, label="warmup_steps", interactive=True
            )

        with gr.Column():
            optimizer = gr.Dropdown(
                ["adam", "adamw", "sgd", "adadelta", "adagrad", "adamax", "asgd", "rmsprop"],
                label="optimizer",
                value="adam",
                interactive=True
            )
            learning_rate = gr.Number(
                value=1e-4, label="learning_rate", interactive=True
            )
            weight_decay = gr.Number(
                value=0.000001, label="weight_decay", interactive=True
            )

    gr.Markdown("## Logs and plots")

    with gr.Row():
        with gr.Column():
            log_output = gr.Textbox(
                show_label=False,
                interactive=False,
                max_lines=23,
                lines=23,
            )
            # Refresh the raw log every 2 seconds.
            demo.load(read_logs, [tempdir_path], log_output, every=2)

        with gr.Column():
            log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
            log_loss = gr.Image(label="Loss", show_label=True, interactive=False)
            # Re-plot every 10 seconds; plot_loss_acc returns (acc, loss) in
            # that order, matching the outputs list.
            demo.load(plot_loss_acc, [tempdir_path, log_every], [log_acc, log_loss], every=10)

    with gr.Row():
        with gr.Column():
            ref_text = gr.Textbox(
                label="Reference text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            base_text = gr.Textbox(
                label="Baseline text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )

    with gr.Row():
        with gr.Column():
            hyp_text = gr.Textbox(
                label="Hypothesis text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            trained_model = gr.File(
                label="Trained model",
                interactive=False,
            )

    with gr.Row():
        finetune_btn = gr.Button("Finetune Model", variant="primary")
        # The input order must match finetune_model's positional parameters.
        finetune_btn.click(
            finetune_model,
            [
                lang,
                task,
                tempdir_path,
                log_every,
                max_epoch,
                scheduler,
                warmup_steps,
                optimizer,
                learning_rate,
                weight_decay,
            ],
            [trained_model, ref_text, base_text, hyp_text]
        )

    gr.Markdown(load_markdown())

if __name__ == "__main__":
    try:
        demo.queue().launch()
    except Exception:
        # Report, then propagate, so the process still exits with the error.
        print("Unexpected error:", sys.exc_info()[0])
        raise
    finally:
        # Per-session directories are removed by delete_tmp_dir. TEMP_DIR is
        # not set anywhere in this file, so only clean it up when an outer
        # launcher provided it. Indexing os.environ directly would raise
        # KeyError here and mask the real exit status or exception.
        legacy_temp_dir = os.environ.get("TEMP_DIR")
        if legacy_temp_dir:
            shutil.rmtree(legacy_temp_dir, ignore_errors=True)